diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml
index 4fd22aff..d21d0684 100644
--- a/.github/workflows/backend-ci.yml
+++ b/.github/workflows/backend-ci.yml
@@ -11,8 +11,8 @@ jobs:
test:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v4
- - uses: actions/setup-go@v5
+ - uses: actions/checkout@v6
+ - uses: actions/setup-go@v6
with:
go-version-file: backend/go.mod
check-latest: false
@@ -30,8 +30,8 @@ jobs:
golangci-lint:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v4
- - uses: actions/setup-go@v5
+ - uses: actions/checkout@v6
+ - uses: actions/setup-go@v6
with:
go-version-file: backend/go.mod
check-latest: false
@@ -43,5 +43,5 @@ jobs:
uses: golangci/golangci-lint-action@v9
with:
version: v2.7
- args: --timeout=5m
+ args: --timeout=30m
working-directory: backend
\ No newline at end of file
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 50bb73e0..a1c6aa23 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -31,7 +31,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
- name: Update VERSION file
run: |
@@ -45,7 +45,7 @@ jobs:
echo "Updated VERSION file to: $VERSION"
- name: Upload VERSION artifact
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v7
with:
name: version-file
path: backend/cmd/server/VERSION
@@ -55,7 +55,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
- name: Setup pnpm
uses: pnpm/action-setup@v4
@@ -63,7 +63,7 @@ jobs:
version: 9
- name: Setup Node.js
- uses: actions/setup-node@v4
+ uses: actions/setup-node@v6
with:
node-version: '20'
cache: 'pnpm'
@@ -78,7 +78,7 @@ jobs:
working-directory: frontend
- name: Upload frontend artifact
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v7
with:
name: frontend-dist
path: backend/internal/web/dist/
@@ -89,25 +89,25 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
with:
fetch-depth: 0
ref: ${{ github.event.inputs.tag || github.ref }}
- name: Download VERSION artifact
- uses: actions/download-artifact@v4
+ uses: actions/download-artifact@v8
with:
name: version-file
path: backend/cmd/server/
- name: Download frontend artifact
- uses: actions/download-artifact@v4
+ uses: actions/download-artifact@v8
with:
name: frontend-dist
path: backend/internal/web/dist/
- name: Setup Go
- uses: actions/setup-go@v5
+ uses: actions/setup-go@v6
with:
go-version-file: backend/go.mod
check-latest: false
@@ -173,7 +173,7 @@ jobs:
run: echo "owner=$(echo '${{ github.repository_owner }}' | tr '[:upper:]' '[:lower:]')" >> $GITHUB_OUTPUT
- name: Run GoReleaser
- uses: goreleaser/goreleaser-action@v6
+ uses: goreleaser/goreleaser-action@v7
with:
version: '~> v2'
args: release --clean --skip=validate ${{ env.SIMPLE_RELEASE == 'true' && '--config=.goreleaser.simple.yaml' || '' }}
@@ -188,7 +188,7 @@ jobs:
# Update DockerHub description
- name: Update DockerHub description
if: ${{ env.SIMPLE_RELEASE != 'true' && env.DOCKERHUB_USERNAME != '' }}
- uses: peter-evans/dockerhub-description@v4
+ uses: peter-evans/dockerhub-description@v5
env:
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
with:
diff --git a/.github/workflows/security-scan.yml b/.github/workflows/security-scan.yml
index fd0c7a41..db922509 100644
--- a/.github/workflows/security-scan.yml
+++ b/.github/workflows/security-scan.yml
@@ -12,10 +12,11 @@ permissions:
jobs:
backend-security:
runs-on: ubuntu-latest
+ timeout-minutes: 15
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v6
- name: Set up Go
- uses: actions/setup-go@v5
+ uses: actions/setup-go@v6
with:
go-version-file: backend/go.mod
check-latest: false
@@ -28,22 +29,17 @@ jobs:
run: |
go install golang.org/x/vuln/cmd/govulncheck@latest
govulncheck ./...
- - name: Run gosec
- working-directory: backend
- run: |
- go install github.com/securego/gosec/v2/cmd/gosec@latest
- gosec -conf .gosec.json -severity high -confidence high ./...
frontend-security:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v6
- name: Set up pnpm
uses: pnpm/action-setup@v4
with:
version: 9
- name: Set up Node.js
- uses: actions/setup-node@v4
+ uses: actions/setup-node@v6
with:
node-version: '20'
cache: 'pnpm'
diff --git a/.gitignore b/.gitignore
index 515ce84f..297c1d6f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -116,13 +116,12 @@ backend/.installed
# ===================
tests
CLAUDE.md
-AGENTS.md
.claude
scripts
.code-review-state
-openspec/
+#openspec/
code-reviews/
-AGENTS.md
+#AGENTS.md
backend/cmd/server/server
deploy/docker-compose.override.yml
.gocache/
@@ -132,4 +131,5 @@ docs/*
.codex/
frontend/coverage/
aicodex
+output/
diff --git a/AGENTS.md b/AGENTS.md
new file mode 100644
index 00000000..bb5bb465
--- /dev/null
+++ b/AGENTS.md
@@ -0,0 +1,105 @@
+# Repository Guidelines
+
+## Project Structure & Module Organization
+- `backend/`: Go service. `cmd/server` is the entrypoint, `internal/` contains handlers/services/repositories/server wiring, `ent/` holds Ent schemas and generated ORM code, `migrations/` stores DB migrations, and `internal/web/dist/` is the embedded frontend build output.
+- `frontend/`: Vue 3 + TypeScript app. Main folders are `src/api`, `src/components`, `src/views`, `src/stores`, `src/composables`, `src/utils`, and test files in `src/**/__tests__`.
+- `deploy/`: Docker and deployment assets (`docker-compose*.yml`, `.env.example`, `config.example.yaml`).
+- `openspec/`: Spec-driven change docs (`changes//{proposal,design,tasks}.md`).
+- `tools/`: Utility scripts (security/perf checks).
+
+## Build, Test, and Development Commands
+```bash
+make build # Build backend + frontend
+make test # Backend tests + frontend lint/typecheck
+cd backend && make build # Build backend binary
+cd backend && make test-unit # Go unit tests
+cd backend && make test-integration # Go integration tests
+cd backend && make test # go test ./... + golangci-lint
+cd frontend && pnpm install --frozen-lockfile
+cd frontend && pnpm dev # Vite dev server
+cd frontend && pnpm build # Type-check + production build
+cd frontend && pnpm test:run # Vitest run
+cd frontend && pnpm test:coverage # Vitest + coverage report
+python3 tools/secret_scan.py # Secret scan
+```
+
+## Coding Style & Naming Conventions
+- Go: format with `gofmt`; lint with `golangci-lint` (`backend/.golangci.yml`).
+- Respect layering: `internal/service` and `internal/handler` must not import `internal/repository`, `gorm`, or `redis` directly (enforced by depguard).
+- Frontend: Vue SFC + TypeScript, 2-space indentation, ESLint rules from `frontend/.eslintrc.cjs`.
+- Naming: components use `PascalCase.vue`, composables use `useXxx.ts`, Go tests use `*_test.go`, frontend tests use `*.spec.ts`.
+
+## Go & Frontend Development Standards
+- Control branch complexity: `if` nesting must not exceed 3 levels. Refactor with guard clauses, early returns, helper functions, or strategy maps when deeper logic appears.
+- JSON hot-path rule: for read-only/partial-field extraction, prefer `gjson` over full `encoding/json` struct unmarshal to reduce allocations and improve latency.
+- Exception rule: if full schema validation or typed writes are required, `encoding/json` is allowed, but PR must explain why `gjson` is not suitable.
+
+### Go Performance Rules
+- Optimization workflow rule: benchmark/profile first, then optimize. Use `go test -bench`, `go tool pprof`, and runtime diagnostics before changing hot-path code.
+- For hot functions, run escape analysis (`go build -gcflags=all='-m -m'`) and prioritize stack allocation where reasonable.
+- Every external I/O path must use `context.Context` with explicit timeout/cancel.
+- When creating derived contexts (`WithTimeout` / `WithDeadline`), always `defer cancel()` to release resources.
+- Preallocate slices/maps when size can be estimated (`make([]T, 0, n)`, `make(map[K]V, n)`).
+- Avoid unnecessary allocations in loops; reuse buffers and prefer `strings.Builder`/`bytes.Buffer`.
+- Prohibit N+1 query patterns; batch DB/Redis operations and verify indexes for new query paths.
+- For hot-path changes, include benchmark or latency comparison evidence (e.g., `go test -bench` before/after).
+- Keep goroutine growth bounded (worker pool/semaphore), and avoid unbounded fan-out.
+- Lock minimization rule: if a lock can be avoided, do not use a lock. Prefer ownership transfer (channel), sharding, immutable snapshots, copy-on-write, or atomic operations to reduce contention.
+- When locks are unavoidable, keep critical sections minimal, avoid nested locks, and document why lock-free alternatives are not feasible.
+- Follow `sync` guidance: prefer channels for higher-level synchronization; use low-level mutex primitives only where necessary.
+- Avoid reflection and `interface{}`-heavy conversions in hot paths; use typed structs/functions.
+- Use `sync.Pool` only when benchmark proves allocation reduction; remove if no measurable gain.
+- Avoid repeated `time.Now()`/`fmt.Sprintf` in tight loops; hoist or cache when possible.
+- For stable high-traffic binaries, maintain representative `default.pgo` profiles and keep `go build -pgo=auto` enabled.
+
+### Data Access & Cache Rules
+- Every new/changed SQL query must be checked with `EXPLAIN` (or `EXPLAIN ANALYZE` in staging) and include index rationale in PR.
+- Default to keyset pagination for large tables; avoid deep `OFFSET` scans on hot endpoints.
+- Query only required columns; prohibit broad `SELECT *` in latency-sensitive paths.
+- Keep transactions short; never perform external RPC/network calls inside DB transactions.
+- Connection pool must be explicitly tuned and observed via `DB.Stats` (`SetMaxOpenConns`, `SetMaxIdleConns`, `SetConnMaxIdleTime`, `SetConnMaxLifetime`).
+- Avoid overly small `MaxOpenConns` that can turn DB access into lock/semaphore bottlenecks.
+- Cache keys must be versioned (e.g., `user_usage:v2:{id}`) and TTL should include jitter to avoid thundering herd.
+- Use request coalescing (`singleflight` or equivalent) for high-concurrency cache miss paths.
+
+### Frontend Performance Rules
+- Route-level and heavy-module code splitting is required; lazy-load non-critical views/components.
+- API requests must support cancellation and deduplication; use debounce/throttle for search-like inputs.
+- Minimize unnecessary reactivity: avoid deep watch chains when computed/cache can solve it.
+- Prefer stable props and selective rendering controls (`v-once`, `v-memo`) for expensive subtrees when data is static or keyed.
+- Large data rendering must use pagination or virtualization (especially tables/lists >200 rows).
+- Move expensive CPU work off the main thread (Web Worker) or chunk tasks to avoid UI blocking.
+- Keep bundle growth controlled; avoid adding heavy dependencies without clear ROI and alternatives review.
+- Avoid expensive inline computations in templates; move to cached `computed` selectors.
+- Keep state normalized; avoid duplicated derived state across multiple stores/components.
+- Load charts/editors/export libraries on demand only (`dynamic import`) instead of app-entry import.
+- Core Web Vitals targets (p75): `LCP <= 2.5s`, `INP <= 200ms`, `CLS <= 0.1`.
+- Main-thread task budget: keep individual tasks below ~50ms; split long tasks and yield between chunks.
+- Enforce frontend budgets in CI (Lighthouse CI with `budget.json`) for critical routes.
+
+### Performance Budget & PR Evidence
+- Performance budget is mandatory for hot-path PRs: backend p95/p99 latency and CPU/memory must not regress by more than 5% versus baseline.
+- Frontend budget: new route-level JS should not increase by more than 30KB gzip without explicit approval.
+- For any gateway/protocol hot path, attach a reproducible benchmark command and results (input size, concurrency, before/after table).
+- Profiling evidence is required for major optimizations (`pprof`, flamegraph, browser performance trace, or bundle analyzer output).
+
+### Quality Gate
+- Any changed code must include new or updated unit tests.
+- Coverage must stay above 85% (global frontend threshold and no regressions for touched backend modules).
+- If any rule is intentionally violated, document reason, risk, and mitigation in the PR description.
+
+## Testing Guidelines
+- Backend suites: `go test -tags=unit ./...`, `go test -tags=integration ./...`, and e2e where relevant.
+- Frontend uses Vitest (`jsdom`); keep tests near modules (`__tests__`) or as `*.spec.ts`.
+- Enforce unit-test and coverage rules defined in `Quality Gate`.
+- Before opening a PR, run `make test` plus targeted tests for touched areas.
+
+## Commit & Pull Request Guidelines
+- Follow Conventional Commits: `feat(scope): ...`, `fix(scope): ...`, `chore(scope): ...`, `docs(scope): ...`.
+- PRs should include a clear summary, linked issue/spec, commands run for verification, and screenshots/GIFs for UI changes.
+- For behavior/API changes, add or update `openspec/changes/...` artifacts.
+- If dependencies change, commit `frontend/pnpm-lock.yaml` in the same PR.
+
+## Security & Configuration Tips
+- Use `deploy/.env.example` and `deploy/config.example.yaml` as templates; do not commit real credentials.
+- Set stable `JWT_SECRET`, `TOTP_ENCRYPTION_KEY`, and strong database passwords outside local dev.
diff --git a/Dockerfile b/Dockerfile
index 645465f1..1493e8a7 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -8,7 +8,7 @@
ARG NODE_IMAGE=node:24-alpine
ARG GOLANG_IMAGE=golang:1.25.7-alpine
-ARG ALPINE_IMAGE=alpine:3.20
+ARG ALPINE_IMAGE=alpine:3.21
ARG GOPROXY=https://goproxy.cn,direct
ARG GOSUMDB=sum.golang.google.cn
@@ -68,6 +68,7 @@ RUN VERSION_VALUE="${VERSION}" && \
CGO_ENABLED=0 GOOS=linux go build \
-tags embed \
-ldflags="-s -w -X main.Version=${VERSION_VALUE} -X main.Commit=${COMMIT} -X main.Date=${DATE_VALUE} -X main.BuildType=release" \
+ -trimpath \
-o /app/sub2api \
./cmd/server
@@ -85,7 +86,6 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api"
RUN apk add --no-cache \
ca-certificates \
tzdata \
- curl \
&& rm -rf /var/cache/apk/*
# Create non-root user
@@ -95,11 +95,12 @@ RUN addgroup -g 1000 sub2api && \
# Set working directory
WORKDIR /app
-# Copy binary from builder
-COPY --from=backend-builder /app/sub2api /app/sub2api
+# Copy binary/resources with ownership to avoid extra full-layer chown copy
+COPY --from=backend-builder --chown=sub2api:sub2api /app/sub2api /app/sub2api
+COPY --from=backend-builder --chown=sub2api:sub2api /app/backend/resources /app/resources
# Create data directory
-RUN mkdir -p /app/data && chown -R sub2api:sub2api /app
+RUN mkdir -p /app/data && chown sub2api:sub2api /app/data
# Switch to non-root user
USER sub2api
@@ -109,7 +110,7 @@ EXPOSE 8080
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
- CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1
+ CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1
# Run the application
ENTRYPOINT ["/app/sub2api"]
diff --git a/Makefile b/Makefile
index b97404eb..fd6a5a9a 100644
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,4 @@
-.PHONY: build build-backend build-frontend test test-backend test-frontend secret-scan
+.PHONY: build build-backend build-frontend build-datamanagementd test test-backend test-frontend test-datamanagementd secret-scan
# 一键编译前后端
build: build-backend build-frontend
@@ -11,6 +11,10 @@ build-backend:
build-frontend:
@pnpm --dir frontend run build
+# 编译 datamanagementd(宿主机数据管理进程)
+build-datamanagementd:
+ @cd datamanagement && go build -o datamanagementd ./cmd/datamanagementd
+
# 运行测试(后端 + 前端)
test: test-backend test-frontend
@@ -21,5 +25,8 @@ test-frontend:
@pnpm --dir frontend run lint:check
@pnpm --dir frontend run typecheck
+test-datamanagementd:
+ @cd datamanagement && go test ./...
+
secret-scan:
@python3 tools/secret_scan.py
diff --git a/README.md b/README.md
index a5f680bf..1e2f2290 100644
--- a/README.md
+++ b/README.md
@@ -54,6 +54,7 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
## Documentation
- Dependency Security: `docs/dependency-security.md`
+- Admin Payment Integration API: `docs/ADMIN_PAYMENT_INTEGRATION_API.md`
---
diff --git a/README_CN.md b/README_CN.md
index ea35a19d..9da089b7 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -62,8 +62,6 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
----
-
## 部署方式
### 方式一:脚本安装(推荐)
@@ -246,6 +244,18 @@ docker-compose -f docker-compose.local.yml logs -f sub2api
**推荐:** 使用 `docker-compose.local.yml`(脚本部署)以便更轻松地管理数据。
+#### 启用“数据管理”功能(datamanagementd)
+
+如需启用管理后台“数据管理”,需要额外部署宿主机数据管理进程 `datamanagementd`。
+
+关键点:
+
+- 主进程固定探测:`/tmp/sub2api-datamanagement.sock`
+- 只有该 Socket 可连通时,数据管理功能才会开启
+- Docker 场景需将宿主机 Socket 挂载到容器同路径
+
+详细部署步骤见:`deploy/DATAMANAGEMENTD_CN.md`
+
#### 访问
在浏览器中打开 `http://你的服务器IP:8080`
diff --git a/backend/.golangci.yml b/backend/.golangci.yml
index 3ec692a8..68b76751 100644
--- a/backend/.golangci.yml
+++ b/backend/.golangci.yml
@@ -5,6 +5,7 @@ linters:
enable:
- depguard
- errcheck
+ - gosec
- govet
- ineffassign
- staticcheck
@@ -42,6 +43,22 @@ linters:
desc: "handler must not import gorm"
- pkg: github.com/redis/go-redis/v9
desc: "handler must not import redis"
+ gosec:
+ excludes:
+ - G101
+ - G103
+ - G104
+ - G109
+ - G115
+ - G201
+ - G202
+ - G301
+ - G302
+ - G304
+ - G306
+ - G404
+ severity: high
+ confidence: high
errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.
diff --git a/backend/.gosec.json b/backend/.gosec.json
deleted file mode 100644
index b34e140c..00000000
--- a/backend/.gosec.json
+++ /dev/null
@@ -1,5 +0,0 @@
-{
- "global": {
- "exclude": "G704"
- }
-}
diff --git a/backend/Makefile b/backend/Makefile
index 89db1104..7084ccb9 100644
--- a/backend/Makefile
+++ b/backend/Makefile
@@ -1,7 +1,14 @@
-.PHONY: build test test-unit test-integration test-e2e
+.PHONY: build generate test test-unit test-integration test-e2e
+
+VERSION ?= $(shell tr -d '\r\n' < ./cmd/server/VERSION)
+LDFLAGS ?= -s -w -X main.Version=$(VERSION)
build:
- go build -o bin/server ./cmd/server
+ CGO_ENABLED=0 go build -ldflags="$(LDFLAGS)" -trimpath -o bin/server ./cmd/server
+
+generate:
+ go generate ./ent
+ go generate ./cmd/server
test:
go test ./...
diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go
index 2ff7358b..bc001693 100644
--- a/backend/cmd/jwtgen/main.go
+++ b/backend/cmd/jwtgen/main.go
@@ -33,7 +33,7 @@ func main() {
}()
userRepo := repository.NewUserRepository(client, sqlDB)
- authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil)
+ authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION
index c0c68bab..32844913 100644
--- a/backend/cmd/server/VERSION
+++ b/backend/cmd/server/VERSION
@@ -1 +1 @@
-0.1.85
+0.1.88
\ No newline at end of file
diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go
index 1ba6b184..cbf89ba3 100644
--- a/backend/cmd/server/wire.go
+++ b/backend/cmd/server/wire.go
@@ -7,6 +7,7 @@ import (
"context"
"log"
"net/http"
+ "sync"
"time"
"github.com/Wei-Shaw/sub2api/ent"
@@ -84,16 +85,19 @@ func provideCleanup(
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
antigravityOAuth *service.AntigravityOAuthService,
+ openAIGateway *service.OpenAIGatewayService,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
- // Cleanup steps in reverse dependency order
- cleanupSteps := []struct {
+ type cleanupStep struct {
name string
fn func() error
- }{
+ }
+
+ // 应用层清理步骤可并行执行,基础设施资源(Redis/Ent)最后按顺序关闭。
+ parallelSteps := []cleanupStep{
{"OpsScheduledReportService", func() error {
if opsScheduledReport != nil {
opsScheduledReport.Stop()
@@ -206,23 +210,60 @@ func provideCleanup(
antigravityOAuth.Stop()
return nil
}},
+ {"OpenAIWSPool", func() error {
+ if openAIGateway != nil {
+ openAIGateway.CloseOpenAIWSPool()
+ }
+ return nil
+ }},
+ }
+
+ infraSteps := []cleanupStep{
{"Redis", func() error {
+ if rdb == nil {
+ return nil
+ }
return rdb.Close()
}},
{"Ent", func() error {
+ if entClient == nil {
+ return nil
+ }
return entClient.Close()
}},
}
- for _, step := range cleanupSteps {
- if err := step.fn(); err != nil {
- log.Printf("[Cleanup] %s failed: %v", step.name, err)
- // Continue with remaining cleanup steps even if one fails
- } else {
+ runParallel := func(steps []cleanupStep) {
+ var wg sync.WaitGroup
+ for i := range steps {
+ step := steps[i]
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ if err := step.fn(); err != nil {
+ log.Printf("[Cleanup] %s failed: %v", step.name, err)
+ return
+ }
+ log.Printf("[Cleanup] %s succeeded", step.name)
+ }()
+ }
+ wg.Wait()
+ }
+
+ runSequential := func(steps []cleanupStep) {
+ for i := range steps {
+ step := steps[i]
+ if err := step.fn(); err != nil {
+ log.Printf("[Cleanup] %s failed: %v", step.name, err)
+ continue
+ }
log.Printf("[Cleanup] %s succeeded", step.name)
}
}
+ runParallel(parallelSteps)
+ runSequential(infraSteps)
+
// Check if context timed out
select {
case <-ctx.Done():
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 287f8176..90709f5b 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -19,6 +19,7 @@ import (
"github.com/redis/go-redis/v9"
"log"
"net/http"
+ "sync"
"time"
)
@@ -47,7 +48,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
redisClient := repository.ProvideRedis(configConfig)
refreshTokenCache := repository.NewRefreshTokenCache(redisClient)
settingRepository := repository.NewSettingRepository(client)
- settingService := service.NewSettingService(settingRepository, configConfig)
+ groupRepository := repository.NewGroupRepository(client, db)
+ settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig)
emailCache := repository.NewEmailCache(redisClient)
emailService := service.NewEmailService(settingRepository, emailCache)
turnstileVerifier := repository.NewTurnstileVerifier()
@@ -58,15 +60,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
apiKeyRepository := repository.NewAPIKeyRepository(client)
- groupRepository := repository.NewGroupRepository(client, db)
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
apiKeyCache := repository.NewAPIKeyCache(redisClient)
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
- authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
- userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
+ authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
+ userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
@@ -102,7 +103,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
- adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
+ adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
@@ -113,7 +114,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
- geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig)
+ driveClient := repository.NewGeminiDriveClient()
+ geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig)
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
@@ -136,14 +138,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
- accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator)
+ rpmCache := repository.NewRPMCache(redisClient)
+ accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
+ dataManagementService := service.NewDataManagementService()
+ dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService)
proxyHandler := admin.NewProxyHandler(adminService)
- adminRedeemHandler := admin.NewRedeemHandler(adminService)
+ adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService)
promoHandler := admin.NewPromoHandler(promoService)
opsRepository := repository.NewOpsRepository(db)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
@@ -156,13 +161,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
digestSessionStore := service.NewDigestSessionStore()
- gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, digestSessionStore)
+ gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
- settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
+ soraS3Storage := service.NewSoraS3Storage(settingService)
+ settingService.SetOnS3UpdateCallback(soraS3Storage.RefreshClient)
+ soraGenerationRepository := repository.NewSoraGenerationRepository(db)
+ soraQuotaService := service.NewSoraQuotaService(userRepository, groupRepository, settingService)
+ soraGenerationService := service.NewSoraGenerationService(soraGenerationRepository, soraS3Storage, soraQuotaService)
+ settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, soraS3Storage)
opsHandler := admin.NewOpsHandler(opsService)
updateCache := repository.NewUpdateCache(redisClient)
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
@@ -183,19 +193,21 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient)
errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
- adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler)
+ adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService)
+ adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
- gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
+ gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig, settingService)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
- soraGatewayService := service.NewSoraGatewayService(soraSDKClient, soraMediaStorage, rateLimitService, configConfig)
+ soraGatewayService := service.NewSoraGatewayService(soraSDKClient, rateLimitService, httpUpstream, configConfig)
+ soraClientHandler := handler.NewSoraClientHandler(soraGenerationService, soraQuotaService, soraS3Storage, soraGatewayService, gatewayService, soraMediaStorage, apiKeyService)
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
totpHandler := handler.NewTotpHandler(totpService)
idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig)
idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig)
- handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService)
+ handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, soraClientHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
@@ -210,7 +222,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
- v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
+ v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -257,15 +269,18 @@ func provideCleanup(
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
antigravityOAuth *service.AntigravityOAuthService,
+ openAIGateway *service.OpenAIGatewayService,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
- cleanupSteps := []struct {
+ type cleanupStep struct {
name string
fn func() error
- }{
+ }
+
+ parallelSteps := []cleanupStep{
{"OpsScheduledReportService", func() error {
if opsScheduledReport != nil {
opsScheduledReport.Stop()
@@ -378,23 +393,60 @@ func provideCleanup(
antigravityOAuth.Stop()
return nil
}},
+ {"OpenAIWSPool", func() error {
+ if openAIGateway != nil {
+ openAIGateway.CloseOpenAIWSPool()
+ }
+ return nil
+ }},
+ }
+
+ infraSteps := []cleanupStep{
{"Redis", func() error {
+ if rdb == nil {
+ return nil
+ }
return rdb.Close()
}},
{"Ent", func() error {
+ if entClient == nil {
+ return nil
+ }
return entClient.Close()
}},
}
- for _, step := range cleanupSteps {
- if err := step.fn(); err != nil {
- log.Printf("[Cleanup] %s failed: %v", step.name, err)
+ runParallel := func(steps []cleanupStep) {
+ var wg sync.WaitGroup
+ for i := range steps {
+ step := steps[i]
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ if err := step.fn(); err != nil {
+ log.Printf("[Cleanup] %s failed: %v", step.name, err)
+ return
+ }
+ log.Printf("[Cleanup] %s succeeded", step.name)
+ }()
+ }
+ wg.Wait()
+ }
- } else {
+ runSequential := func(steps []cleanupStep) {
+ for i := range steps {
+ step := steps[i]
+ if err := step.fn(); err != nil {
+ log.Printf("[Cleanup] %s failed: %v", step.name, err)
+ continue
+ }
log.Printf("[Cleanup] %s succeeded", step.name)
}
}
+ runParallel(parallelSteps)
+ runSequential(infraSteps)
+
select {
case <-ctx.Done():
log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds")
diff --git a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go
new file mode 100644
index 00000000..9fb9888d
--- /dev/null
+++ b/backend/cmd/server/wire_gen_test.go
@@ -0,0 +1,81 @@
+package main
+
+import (
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/handler"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
+func TestProvideServiceBuildInfo(t *testing.T) {
+ in := handler.BuildInfo{
+ Version: "v-test",
+ BuildType: "release",
+ }
+ out := provideServiceBuildInfo(in)
+ require.Equal(t, in.Version, out.Version)
+ require.Equal(t, in.BuildType, out.BuildType)
+}
+
+func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
+ cfg := &config.Config{}
+
+ oauthSvc := service.NewOAuthService(nil, nil)
+ openAIOAuthSvc := service.NewOpenAIOAuthService(nil, nil)
+ geminiOAuthSvc := service.NewGeminiOAuthService(nil, nil, nil, nil, cfg)
+ antigravityOAuthSvc := service.NewAntigravityOAuthService(nil)
+
+ tokenRefreshSvc := service.NewTokenRefreshService(
+ nil,
+ oauthSvc,
+ openAIOAuthSvc,
+ geminiOAuthSvc,
+ antigravityOAuthSvc,
+ nil,
+ nil,
+ cfg,
+ )
+ accountExpirySvc := service.NewAccountExpiryService(nil, time.Second)
+ subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second)
+ pricingSvc := service.NewPricingService(cfg, nil)
+ emailQueueSvc := service.NewEmailQueueService(nil, 1)
+ billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg)
+ idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg)
+ schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg)
+ opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil)
+
+ cleanup := provideCleanup(
+ nil, // entClient
+ nil, // redis
+ &service.OpsMetricsCollector{},
+ &service.OpsAggregationService{},
+ &service.OpsAlertEvaluatorService{},
+ &service.OpsCleanupService{},
+ &service.OpsScheduledReportService{},
+ opsSystemLogSinkSvc,
+ &service.SoraMediaCleanupService{},
+ schedulerSnapshotSvc,
+ tokenRefreshSvc,
+ accountExpirySvc,
+ subscriptionExpirySvc,
+ &service.UsageCleanupService{},
+ idempotencyCleanupSvc,
+ pricingSvc,
+ emailQueueSvc,
+ billingCacheSvc,
+ &service.UsageRecordWorkerPool{},
+ &service.SubscriptionService{},
+ oauthSvc,
+ openAIOAuthSvc,
+ geminiOAuthSvc,
+ antigravityOAuthSvc,
+ nil, // openAIGateway
+ )
+
+ require.NotPanics(t, func() {
+ cleanup()
+ })
+}
diff --git a/backend/ent/account.go b/backend/ent/account.go
index 038aa7e5..c77002b3 100644
--- a/backend/ent/account.go
+++ b/backend/ent/account.go
@@ -63,6 +63,10 @@ type Account struct {
RateLimitResetAt *time.Time `json:"rate_limit_reset_at,omitempty"`
// OverloadUntil holds the value of the "overload_until" field.
OverloadUntil *time.Time `json:"overload_until,omitempty"`
+ // TempUnschedulableUntil holds the value of the "temp_unschedulable_until" field.
+ TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"`
+ // TempUnschedulableReason holds the value of the "temp_unschedulable_reason" field.
+ TempUnschedulableReason *string `json:"temp_unschedulable_reason,omitempty"`
// SessionWindowStart holds the value of the "session_window_start" field.
SessionWindowStart *time.Time `json:"session_window_start,omitempty"`
// SessionWindowEnd holds the value of the "session_window_end" field.
@@ -141,9 +145,9 @@ func (*Account) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64)
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority:
values[i] = new(sql.NullInt64)
- case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldSessionWindowStatus:
+ case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldTempUnschedulableReason, account.FieldSessionWindowStatus:
values[i] = new(sql.NullString)
- case account.FieldCreatedAt, account.FieldUpdatedAt, account.FieldDeletedAt, account.FieldLastUsedAt, account.FieldExpiresAt, account.FieldRateLimitedAt, account.FieldRateLimitResetAt, account.FieldOverloadUntil, account.FieldSessionWindowStart, account.FieldSessionWindowEnd:
+ case account.FieldCreatedAt, account.FieldUpdatedAt, account.FieldDeletedAt, account.FieldLastUsedAt, account.FieldExpiresAt, account.FieldRateLimitedAt, account.FieldRateLimitResetAt, account.FieldOverloadUntil, account.FieldTempUnschedulableUntil, account.FieldSessionWindowStart, account.FieldSessionWindowEnd:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@@ -311,6 +315,20 @@ func (_m *Account) assignValues(columns []string, values []any) error {
_m.OverloadUntil = new(time.Time)
*_m.OverloadUntil = value.Time
}
+ case account.FieldTempUnschedulableUntil:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field temp_unschedulable_until", values[i])
+ } else if value.Valid {
+ _m.TempUnschedulableUntil = new(time.Time)
+ *_m.TempUnschedulableUntil = value.Time
+ }
+ case account.FieldTempUnschedulableReason:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field temp_unschedulable_reason", values[i])
+ } else if value.Valid {
+ _m.TempUnschedulableReason = new(string)
+ *_m.TempUnschedulableReason = value.String
+ }
case account.FieldSessionWindowStart:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field session_window_start", values[i])
@@ -472,6 +490,16 @@ func (_m *Account) String() string {
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
+ if v := _m.TempUnschedulableUntil; v != nil {
+ builder.WriteString("temp_unschedulable_until=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.TempUnschedulableReason; v != nil {
+ builder.WriteString("temp_unschedulable_reason=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
if v := _m.SessionWindowStart; v != nil {
builder.WriteString("session_window_start=")
builder.WriteString(v.Format(time.ANSIC))
diff --git a/backend/ent/account/account.go b/backend/ent/account/account.go
index 73c0e8c2..1fc34620 100644
--- a/backend/ent/account/account.go
+++ b/backend/ent/account/account.go
@@ -59,6 +59,10 @@ const (
FieldRateLimitResetAt = "rate_limit_reset_at"
// FieldOverloadUntil holds the string denoting the overload_until field in the database.
FieldOverloadUntil = "overload_until"
+ // FieldTempUnschedulableUntil holds the string denoting the temp_unschedulable_until field in the database.
+ FieldTempUnschedulableUntil = "temp_unschedulable_until"
+ // FieldTempUnschedulableReason holds the string denoting the temp_unschedulable_reason field in the database.
+ FieldTempUnschedulableReason = "temp_unschedulable_reason"
// FieldSessionWindowStart holds the string denoting the session_window_start field in the database.
FieldSessionWindowStart = "session_window_start"
// FieldSessionWindowEnd holds the string denoting the session_window_end field in the database.
@@ -128,6 +132,8 @@ var Columns = []string{
FieldRateLimitedAt,
FieldRateLimitResetAt,
FieldOverloadUntil,
+ FieldTempUnschedulableUntil,
+ FieldTempUnschedulableReason,
FieldSessionWindowStart,
FieldSessionWindowEnd,
FieldSessionWindowStatus,
@@ -299,6 +305,16 @@ func ByOverloadUntil(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldOverloadUntil, opts...).ToFunc()
}
+// ByTempUnschedulableUntil orders the results by the temp_unschedulable_until field.
+func ByTempUnschedulableUntil(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTempUnschedulableUntil, opts...).ToFunc()
+}
+
+// ByTempUnschedulableReason orders the results by the temp_unschedulable_reason field.
+func ByTempUnschedulableReason(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTempUnschedulableReason, opts...).ToFunc()
+}
+
// BySessionWindowStart orders the results by the session_window_start field.
func BySessionWindowStart(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSessionWindowStart, opts...).ToFunc()
diff --git a/backend/ent/account/where.go b/backend/ent/account/where.go
index dea1127a..54db1dcb 100644
--- a/backend/ent/account/where.go
+++ b/backend/ent/account/where.go
@@ -155,6 +155,16 @@ func OverloadUntil(v time.Time) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldOverloadUntil, v))
}
+// TempUnschedulableUntil applies equality check predicate on the "temp_unschedulable_until" field. It's identical to TempUnschedulableUntilEQ.
+func TempUnschedulableUntil(v time.Time) predicate.Account {
+ return predicate.Account(sql.FieldEQ(FieldTempUnschedulableUntil, v))
+}
+
+// TempUnschedulableReason applies equality check predicate on the "temp_unschedulable_reason" field. It's identical to TempUnschedulableReasonEQ.
+func TempUnschedulableReason(v string) predicate.Account {
+ return predicate.Account(sql.FieldEQ(FieldTempUnschedulableReason, v))
+}
+
// SessionWindowStart applies equality check predicate on the "session_window_start" field. It's identical to SessionWindowStartEQ.
func SessionWindowStart(v time.Time) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldSessionWindowStart, v))
@@ -1130,6 +1140,131 @@ func OverloadUntilNotNil() predicate.Account {
return predicate.Account(sql.FieldNotNull(FieldOverloadUntil))
}
+// TempUnschedulableUntilEQ applies the EQ predicate on the "temp_unschedulable_until" field.
+func TempUnschedulableUntilEQ(v time.Time) predicate.Account {
+ return predicate.Account(sql.FieldEQ(FieldTempUnschedulableUntil, v))
+}
+
+// TempUnschedulableUntilNEQ applies the NEQ predicate on the "temp_unschedulable_until" field.
+func TempUnschedulableUntilNEQ(v time.Time) predicate.Account {
+ return predicate.Account(sql.FieldNEQ(FieldTempUnschedulableUntil, v))
+}
+
+// TempUnschedulableUntilIn applies the In predicate on the "temp_unschedulable_until" field.
+func TempUnschedulableUntilIn(vs ...time.Time) predicate.Account {
+ return predicate.Account(sql.FieldIn(FieldTempUnschedulableUntil, vs...))
+}
+
+// TempUnschedulableUntilNotIn applies the NotIn predicate on the "temp_unschedulable_until" field.
+func TempUnschedulableUntilNotIn(vs ...time.Time) predicate.Account {
+ return predicate.Account(sql.FieldNotIn(FieldTempUnschedulableUntil, vs...))
+}
+
+// TempUnschedulableUntilGT applies the GT predicate on the "temp_unschedulable_until" field.
+func TempUnschedulableUntilGT(v time.Time) predicate.Account {
+ return predicate.Account(sql.FieldGT(FieldTempUnschedulableUntil, v))
+}
+
+// TempUnschedulableUntilGTE applies the GTE predicate on the "temp_unschedulable_until" field.
+func TempUnschedulableUntilGTE(v time.Time) predicate.Account {
+ return predicate.Account(sql.FieldGTE(FieldTempUnschedulableUntil, v))
+}
+
+// TempUnschedulableUntilLT applies the LT predicate on the "temp_unschedulable_until" field.
+func TempUnschedulableUntilLT(v time.Time) predicate.Account {
+ return predicate.Account(sql.FieldLT(FieldTempUnschedulableUntil, v))
+}
+
+// TempUnschedulableUntilLTE applies the LTE predicate on the "temp_unschedulable_until" field.
+func TempUnschedulableUntilLTE(v time.Time) predicate.Account {
+ return predicate.Account(sql.FieldLTE(FieldTempUnschedulableUntil, v))
+}
+
+// TempUnschedulableUntilIsNil applies the IsNil predicate on the "temp_unschedulable_until" field.
+func TempUnschedulableUntilIsNil() predicate.Account {
+ return predicate.Account(sql.FieldIsNull(FieldTempUnschedulableUntil))
+}
+
+// TempUnschedulableUntilNotNil applies the NotNil predicate on the "temp_unschedulable_until" field.
+func TempUnschedulableUntilNotNil() predicate.Account {
+ return predicate.Account(sql.FieldNotNull(FieldTempUnschedulableUntil))
+}
+
+// TempUnschedulableReasonEQ applies the EQ predicate on the "temp_unschedulable_reason" field.
+func TempUnschedulableReasonEQ(v string) predicate.Account {
+ return predicate.Account(sql.FieldEQ(FieldTempUnschedulableReason, v))
+}
+
+// TempUnschedulableReasonNEQ applies the NEQ predicate on the "temp_unschedulable_reason" field.
+func TempUnschedulableReasonNEQ(v string) predicate.Account {
+ return predicate.Account(sql.FieldNEQ(FieldTempUnschedulableReason, v))
+}
+
+// TempUnschedulableReasonIn applies the In predicate on the "temp_unschedulable_reason" field.
+func TempUnschedulableReasonIn(vs ...string) predicate.Account {
+ return predicate.Account(sql.FieldIn(FieldTempUnschedulableReason, vs...))
+}
+
+// TempUnschedulableReasonNotIn applies the NotIn predicate on the "temp_unschedulable_reason" field.
+func TempUnschedulableReasonNotIn(vs ...string) predicate.Account {
+ return predicate.Account(sql.FieldNotIn(FieldTempUnschedulableReason, vs...))
+}
+
+// TempUnschedulableReasonGT applies the GT predicate on the "temp_unschedulable_reason" field.
+func TempUnschedulableReasonGT(v string) predicate.Account {
+ return predicate.Account(sql.FieldGT(FieldTempUnschedulableReason, v))
+}
+
+// TempUnschedulableReasonGTE applies the GTE predicate on the "temp_unschedulable_reason" field.
+func TempUnschedulableReasonGTE(v string) predicate.Account {
+ return predicate.Account(sql.FieldGTE(FieldTempUnschedulableReason, v))
+}
+
+// TempUnschedulableReasonLT applies the LT predicate on the "temp_unschedulable_reason" field.
+func TempUnschedulableReasonLT(v string) predicate.Account {
+ return predicate.Account(sql.FieldLT(FieldTempUnschedulableReason, v))
+}
+
+// TempUnschedulableReasonLTE applies the LTE predicate on the "temp_unschedulable_reason" field.
+func TempUnschedulableReasonLTE(v string) predicate.Account {
+ return predicate.Account(sql.FieldLTE(FieldTempUnschedulableReason, v))
+}
+
+// TempUnschedulableReasonContains applies the Contains predicate on the "temp_unschedulable_reason" field.
+func TempUnschedulableReasonContains(v string) predicate.Account {
+ return predicate.Account(sql.FieldContains(FieldTempUnschedulableReason, v))
+}
+
+// TempUnschedulableReasonHasPrefix applies the HasPrefix predicate on the "temp_unschedulable_reason" field.
+func TempUnschedulableReasonHasPrefix(v string) predicate.Account {
+ return predicate.Account(sql.FieldHasPrefix(FieldTempUnschedulableReason, v))
+}
+
+// TempUnschedulableReasonHasSuffix applies the HasSuffix predicate on the "temp_unschedulable_reason" field.
+func TempUnschedulableReasonHasSuffix(v string) predicate.Account {
+ return predicate.Account(sql.FieldHasSuffix(FieldTempUnschedulableReason, v))
+}
+
+// TempUnschedulableReasonIsNil applies the IsNil predicate on the "temp_unschedulable_reason" field.
+func TempUnschedulableReasonIsNil() predicate.Account {
+ return predicate.Account(sql.FieldIsNull(FieldTempUnschedulableReason))
+}
+
+// TempUnschedulableReasonNotNil applies the NotNil predicate on the "temp_unschedulable_reason" field.
+func TempUnschedulableReasonNotNil() predicate.Account {
+ return predicate.Account(sql.FieldNotNull(FieldTempUnschedulableReason))
+}
+
+// TempUnschedulableReasonEqualFold applies the EqualFold predicate on the "temp_unschedulable_reason" field.
+func TempUnschedulableReasonEqualFold(v string) predicate.Account {
+ return predicate.Account(sql.FieldEqualFold(FieldTempUnschedulableReason, v))
+}
+
+// TempUnschedulableReasonContainsFold applies the ContainsFold predicate on the "temp_unschedulable_reason" field.
+func TempUnschedulableReasonContainsFold(v string) predicate.Account {
+ return predicate.Account(sql.FieldContainsFold(FieldTempUnschedulableReason, v))
+}
+
// SessionWindowStartEQ applies the EQ predicate on the "session_window_start" field.
func SessionWindowStartEQ(v time.Time) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldSessionWindowStart, v))
diff --git a/backend/ent/account_create.go b/backend/ent/account_create.go
index 42a561cf..963ffee8 100644
--- a/backend/ent/account_create.go
+++ b/backend/ent/account_create.go
@@ -293,6 +293,34 @@ func (_c *AccountCreate) SetNillableOverloadUntil(v *time.Time) *AccountCreate {
return _c
}
+// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field.
+func (_c *AccountCreate) SetTempUnschedulableUntil(v time.Time) *AccountCreate {
+ _c.mutation.SetTempUnschedulableUntil(v)
+ return _c
+}
+
+// SetNillableTempUnschedulableUntil sets the "temp_unschedulable_until" field if the given value is not nil.
+func (_c *AccountCreate) SetNillableTempUnschedulableUntil(v *time.Time) *AccountCreate {
+ if v != nil {
+ _c.SetTempUnschedulableUntil(*v)
+ }
+ return _c
+}
+
+// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field.
+func (_c *AccountCreate) SetTempUnschedulableReason(v string) *AccountCreate {
+ _c.mutation.SetTempUnschedulableReason(v)
+ return _c
+}
+
+// SetNillableTempUnschedulableReason sets the "temp_unschedulable_reason" field if the given value is not nil.
+func (_c *AccountCreate) SetNillableTempUnschedulableReason(v *string) *AccountCreate {
+ if v != nil {
+ _c.SetTempUnschedulableReason(*v)
+ }
+ return _c
+}
+
// SetSessionWindowStart sets the "session_window_start" field.
func (_c *AccountCreate) SetSessionWindowStart(v time.Time) *AccountCreate {
_c.mutation.SetSessionWindowStart(v)
@@ -639,6 +667,14 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) {
_spec.SetField(account.FieldOverloadUntil, field.TypeTime, value)
_node.OverloadUntil = &value
}
+ if value, ok := _c.mutation.TempUnschedulableUntil(); ok {
+ _spec.SetField(account.FieldTempUnschedulableUntil, field.TypeTime, value)
+ _node.TempUnschedulableUntil = &value
+ }
+ if value, ok := _c.mutation.TempUnschedulableReason(); ok {
+ _spec.SetField(account.FieldTempUnschedulableReason, field.TypeString, value)
+ _node.TempUnschedulableReason = &value
+ }
if value, ok := _c.mutation.SessionWindowStart(); ok {
_spec.SetField(account.FieldSessionWindowStart, field.TypeTime, value)
_node.SessionWindowStart = &value
@@ -1080,6 +1116,42 @@ func (u *AccountUpsert) ClearOverloadUntil() *AccountUpsert {
return u
}
+// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field.
+func (u *AccountUpsert) SetTempUnschedulableUntil(v time.Time) *AccountUpsert {
+ u.Set(account.FieldTempUnschedulableUntil, v)
+ return u
+}
+
+// UpdateTempUnschedulableUntil sets the "temp_unschedulable_until" field to the value that was provided on create.
+func (u *AccountUpsert) UpdateTempUnschedulableUntil() *AccountUpsert {
+ u.SetExcluded(account.FieldTempUnschedulableUntil)
+ return u
+}
+
+// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field.
+func (u *AccountUpsert) ClearTempUnschedulableUntil() *AccountUpsert {
+ u.SetNull(account.FieldTempUnschedulableUntil)
+ return u
+}
+
+// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field.
+func (u *AccountUpsert) SetTempUnschedulableReason(v string) *AccountUpsert {
+ u.Set(account.FieldTempUnschedulableReason, v)
+ return u
+}
+
+// UpdateTempUnschedulableReason sets the "temp_unschedulable_reason" field to the value that was provided on create.
+func (u *AccountUpsert) UpdateTempUnschedulableReason() *AccountUpsert {
+ u.SetExcluded(account.FieldTempUnschedulableReason)
+ return u
+}
+
+// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field.
+func (u *AccountUpsert) ClearTempUnschedulableReason() *AccountUpsert {
+ u.SetNull(account.FieldTempUnschedulableReason)
+ return u
+}
+
// SetSessionWindowStart sets the "session_window_start" field.
func (u *AccountUpsert) SetSessionWindowStart(v time.Time) *AccountUpsert {
u.Set(account.FieldSessionWindowStart, v)
@@ -1557,6 +1629,48 @@ func (u *AccountUpsertOne) ClearOverloadUntil() *AccountUpsertOne {
})
}
+// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field.
+func (u *AccountUpsertOne) SetTempUnschedulableUntil(v time.Time) *AccountUpsertOne {
+ return u.Update(func(s *AccountUpsert) {
+ s.SetTempUnschedulableUntil(v)
+ })
+}
+
+// UpdateTempUnschedulableUntil sets the "temp_unschedulable_until" field to the value that was provided on create.
+func (u *AccountUpsertOne) UpdateTempUnschedulableUntil() *AccountUpsertOne {
+ return u.Update(func(s *AccountUpsert) {
+ s.UpdateTempUnschedulableUntil()
+ })
+}
+
+// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field.
+func (u *AccountUpsertOne) ClearTempUnschedulableUntil() *AccountUpsertOne {
+ return u.Update(func(s *AccountUpsert) {
+ s.ClearTempUnschedulableUntil()
+ })
+}
+
+// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field.
+func (u *AccountUpsertOne) SetTempUnschedulableReason(v string) *AccountUpsertOne {
+ return u.Update(func(s *AccountUpsert) {
+ s.SetTempUnschedulableReason(v)
+ })
+}
+
+// UpdateTempUnschedulableReason sets the "temp_unschedulable_reason" field to the value that was provided on create.
+func (u *AccountUpsertOne) UpdateTempUnschedulableReason() *AccountUpsertOne {
+ return u.Update(func(s *AccountUpsert) {
+ s.UpdateTempUnschedulableReason()
+ })
+}
+
+// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field.
+func (u *AccountUpsertOne) ClearTempUnschedulableReason() *AccountUpsertOne {
+ return u.Update(func(s *AccountUpsert) {
+ s.ClearTempUnschedulableReason()
+ })
+}
+
// SetSessionWindowStart sets the "session_window_start" field.
func (u *AccountUpsertOne) SetSessionWindowStart(v time.Time) *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
@@ -2209,6 +2323,48 @@ func (u *AccountUpsertBulk) ClearOverloadUntil() *AccountUpsertBulk {
})
}
+// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field.
+func (u *AccountUpsertBulk) SetTempUnschedulableUntil(v time.Time) *AccountUpsertBulk {
+ return u.Update(func(s *AccountUpsert) {
+ s.SetTempUnschedulableUntil(v)
+ })
+}
+
+// UpdateTempUnschedulableUntil sets the "temp_unschedulable_until" field to the value that was provided on create.
+func (u *AccountUpsertBulk) UpdateTempUnschedulableUntil() *AccountUpsertBulk {
+ return u.Update(func(s *AccountUpsert) {
+ s.UpdateTempUnschedulableUntil()
+ })
+}
+
+// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field.
+func (u *AccountUpsertBulk) ClearTempUnschedulableUntil() *AccountUpsertBulk {
+ return u.Update(func(s *AccountUpsert) {
+ s.ClearTempUnschedulableUntil()
+ })
+}
+
+// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field.
+func (u *AccountUpsertBulk) SetTempUnschedulableReason(v string) *AccountUpsertBulk {
+ return u.Update(func(s *AccountUpsert) {
+ s.SetTempUnschedulableReason(v)
+ })
+}
+
+// UpdateTempUnschedulableReason sets the "temp_unschedulable_reason" field to the value that was provided on create.
+func (u *AccountUpsertBulk) UpdateTempUnschedulableReason() *AccountUpsertBulk {
+ return u.Update(func(s *AccountUpsert) {
+ s.UpdateTempUnschedulableReason()
+ })
+}
+
+// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field.
+func (u *AccountUpsertBulk) ClearTempUnschedulableReason() *AccountUpsertBulk {
+ return u.Update(func(s *AccountUpsert) {
+ s.ClearTempUnschedulableReason()
+ })
+}
+
// SetSessionWindowStart sets the "session_window_start" field.
func (u *AccountUpsertBulk) SetSessionWindowStart(v time.Time) *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
diff --git a/backend/ent/account_update.go b/backend/ent/account_update.go
index 63fab096..875888e0 100644
--- a/backend/ent/account_update.go
+++ b/backend/ent/account_update.go
@@ -376,6 +376,46 @@ func (_u *AccountUpdate) ClearOverloadUntil() *AccountUpdate {
return _u
}
+// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field.
+func (_u *AccountUpdate) SetTempUnschedulableUntil(v time.Time) *AccountUpdate {
+ _u.mutation.SetTempUnschedulableUntil(v)
+ return _u
+}
+
+// SetNillableTempUnschedulableUntil sets the "temp_unschedulable_until" field if the given value is not nil.
+func (_u *AccountUpdate) SetNillableTempUnschedulableUntil(v *time.Time) *AccountUpdate {
+ if v != nil {
+ _u.SetTempUnschedulableUntil(*v)
+ }
+ return _u
+}
+
+// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field.
+func (_u *AccountUpdate) ClearTempUnschedulableUntil() *AccountUpdate {
+ _u.mutation.ClearTempUnschedulableUntil()
+ return _u
+}
+
+// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field.
+func (_u *AccountUpdate) SetTempUnschedulableReason(v string) *AccountUpdate {
+ _u.mutation.SetTempUnschedulableReason(v)
+ return _u
+}
+
+// SetNillableTempUnschedulableReason sets the "temp_unschedulable_reason" field if the given value is not nil.
+func (_u *AccountUpdate) SetNillableTempUnschedulableReason(v *string) *AccountUpdate {
+ if v != nil {
+ _u.SetTempUnschedulableReason(*v)
+ }
+ return _u
+}
+
+// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field.
+func (_u *AccountUpdate) ClearTempUnschedulableReason() *AccountUpdate {
+ _u.mutation.ClearTempUnschedulableReason()
+ return _u
+}
+
// SetSessionWindowStart sets the "session_window_start" field.
func (_u *AccountUpdate) SetSessionWindowStart(v time.Time) *AccountUpdate {
_u.mutation.SetSessionWindowStart(v)
@@ -701,6 +741,18 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.OverloadUntilCleared() {
_spec.ClearField(account.FieldOverloadUntil, field.TypeTime)
}
+ if value, ok := _u.mutation.TempUnschedulableUntil(); ok {
+ _spec.SetField(account.FieldTempUnschedulableUntil, field.TypeTime, value)
+ }
+ if _u.mutation.TempUnschedulableUntilCleared() {
+ _spec.ClearField(account.FieldTempUnschedulableUntil, field.TypeTime)
+ }
+ if value, ok := _u.mutation.TempUnschedulableReason(); ok {
+ _spec.SetField(account.FieldTempUnschedulableReason, field.TypeString, value)
+ }
+ if _u.mutation.TempUnschedulableReasonCleared() {
+ _spec.ClearField(account.FieldTempUnschedulableReason, field.TypeString)
+ }
if value, ok := _u.mutation.SessionWindowStart(); ok {
_spec.SetField(account.FieldSessionWindowStart, field.TypeTime, value)
}
@@ -1215,6 +1267,46 @@ func (_u *AccountUpdateOne) ClearOverloadUntil() *AccountUpdateOne {
return _u
}
+// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field.
+func (_u *AccountUpdateOne) SetTempUnschedulableUntil(v time.Time) *AccountUpdateOne {
+ _u.mutation.SetTempUnschedulableUntil(v)
+ return _u
+}
+
+// SetNillableTempUnschedulableUntil sets the "temp_unschedulable_until" field if the given value is not nil.
+func (_u *AccountUpdateOne) SetNillableTempUnschedulableUntil(v *time.Time) *AccountUpdateOne {
+ if v != nil {
+ _u.SetTempUnschedulableUntil(*v)
+ }
+ return _u
+}
+
+// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field.
+func (_u *AccountUpdateOne) ClearTempUnschedulableUntil() *AccountUpdateOne {
+ _u.mutation.ClearTempUnschedulableUntil()
+ return _u
+}
+
+// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field.
+func (_u *AccountUpdateOne) SetTempUnschedulableReason(v string) *AccountUpdateOne {
+ _u.mutation.SetTempUnschedulableReason(v)
+ return _u
+}
+
+// SetNillableTempUnschedulableReason sets the "temp_unschedulable_reason" field if the given value is not nil.
+func (_u *AccountUpdateOne) SetNillableTempUnschedulableReason(v *string) *AccountUpdateOne {
+ if v != nil {
+ _u.SetTempUnschedulableReason(*v)
+ }
+ return _u
+}
+
+// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field.
+func (_u *AccountUpdateOne) ClearTempUnschedulableReason() *AccountUpdateOne {
+ _u.mutation.ClearTempUnschedulableReason()
+ return _u
+}
+
// SetSessionWindowStart sets the "session_window_start" field.
func (_u *AccountUpdateOne) SetSessionWindowStart(v time.Time) *AccountUpdateOne {
_u.mutation.SetSessionWindowStart(v)
@@ -1570,6 +1662,18 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er
if _u.mutation.OverloadUntilCleared() {
_spec.ClearField(account.FieldOverloadUntil, field.TypeTime)
}
+ if value, ok := _u.mutation.TempUnschedulableUntil(); ok {
+ _spec.SetField(account.FieldTempUnschedulableUntil, field.TypeTime, value)
+ }
+ if _u.mutation.TempUnschedulableUntilCleared() {
+ _spec.ClearField(account.FieldTempUnschedulableUntil, field.TypeTime)
+ }
+ if value, ok := _u.mutation.TempUnschedulableReason(); ok {
+ _spec.SetField(account.FieldTempUnschedulableReason, field.TypeString, value)
+ }
+ if _u.mutation.TempUnschedulableReasonCleared() {
+ _spec.ClearField(account.FieldTempUnschedulableReason, field.TypeString)
+ }
if value, ok := _u.mutation.SessionWindowStart(); ok {
_spec.SetField(account.FieldSessionWindowStart, field.TypeTime, value)
}
diff --git a/backend/ent/client.go b/backend/ent/client.go
index 504c1755..7ebbaa32 100644
--- a/backend/ent/client.go
+++ b/backend/ent/client.go
@@ -22,6 +22,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
+ "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
@@ -58,6 +59,8 @@ type Client struct {
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
+ // IdempotencyRecord is the client for interacting with the IdempotencyRecord builders.
+ IdempotencyRecord *IdempotencyRecordClient
// PromoCode is the client for interacting with the PromoCode builders.
PromoCode *PromoCodeClient
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
@@ -102,6 +105,7 @@ func (c *Client) init() {
c.AnnouncementRead = NewAnnouncementReadClient(c.config)
c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config)
c.Group = NewGroupClient(c.config)
+ c.IdempotencyRecord = NewIdempotencyRecordClient(c.config)
c.PromoCode = NewPromoCodeClient(c.config)
c.PromoCodeUsage = NewPromoCodeUsageClient(c.config)
c.Proxy = NewProxyClient(c.config)
@@ -214,6 +218,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
AnnouncementRead: NewAnnouncementReadClient(cfg),
ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
Group: NewGroupClient(cfg),
+ IdempotencyRecord: NewIdempotencyRecordClient(cfg),
PromoCode: NewPromoCodeClient(cfg),
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
Proxy: NewProxyClient(cfg),
@@ -253,6 +258,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
AnnouncementRead: NewAnnouncementReadClient(cfg),
ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
Group: NewGroupClient(cfg),
+ IdempotencyRecord: NewIdempotencyRecordClient(cfg),
PromoCode: NewPromoCodeClient(cfg),
PromoCodeUsage: NewPromoCodeUsageClient(cfg),
Proxy: NewProxyClient(cfg),
@@ -296,10 +302,10 @@ func (c *Client) Close() error {
func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
- c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy,
- c.RedeemCode, c.SecuritySecret, c.Setting, c.UsageCleanupTask, c.UsageLog,
- c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
- c.UserSubscription,
+ c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PromoCode,
+ c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
+ c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup,
+ c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
} {
n.Use(hooks...)
}
@@ -310,10 +316,10 @@ func (c *Client) Use(hooks ...Hook) {
func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
- c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy,
- c.RedeemCode, c.SecuritySecret, c.Setting, c.UsageCleanupTask, c.UsageLog,
- c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
- c.UserSubscription,
+ c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PromoCode,
+ c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
+ c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup,
+ c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
} {
n.Intercept(interceptors...)
}
@@ -336,6 +342,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.ErrorPassthroughRule.mutate(ctx, m)
case *GroupMutation:
return c.Group.mutate(ctx, m)
+ case *IdempotencyRecordMutation:
+ return c.IdempotencyRecord.mutate(ctx, m)
case *PromoCodeMutation:
return c.PromoCode.mutate(ctx, m)
case *PromoCodeUsageMutation:
@@ -1575,6 +1583,139 @@ func (c *GroupClient) mutate(ctx context.Context, m *GroupMutation) (Value, erro
}
}
+// IdempotencyRecordClient is a client for the IdempotencyRecord schema.
+type IdempotencyRecordClient struct {
+ config
+}
+
+// NewIdempotencyRecordClient returns a client for the IdempotencyRecord from the given config.
+func NewIdempotencyRecordClient(c config) *IdempotencyRecordClient {
+ return &IdempotencyRecordClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `idempotencyrecord.Hooks(f(g(h())))`.
+func (c *IdempotencyRecordClient) Use(hooks ...Hook) {
+ c.hooks.IdempotencyRecord = append(c.hooks.IdempotencyRecord, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `idempotencyrecord.Intercept(f(g(h())))`.
+func (c *IdempotencyRecordClient) Intercept(interceptors ...Interceptor) {
+ c.inters.IdempotencyRecord = append(c.inters.IdempotencyRecord, interceptors...)
+}
+
+// Create returns a builder for creating a IdempotencyRecord entity.
+func (c *IdempotencyRecordClient) Create() *IdempotencyRecordCreate {
+ mutation := newIdempotencyRecordMutation(c.config, OpCreate)
+ return &IdempotencyRecordCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of IdempotencyRecord entities.
+func (c *IdempotencyRecordClient) CreateBulk(builders ...*IdempotencyRecordCreate) *IdempotencyRecordCreateBulk {
+ return &IdempotencyRecordCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *IdempotencyRecordClient) MapCreateBulk(slice any, setFunc func(*IdempotencyRecordCreate, int)) *IdempotencyRecordCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &IdempotencyRecordCreateBulk{err: fmt.Errorf("calling to IdempotencyRecordClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*IdempotencyRecordCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &IdempotencyRecordCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for IdempotencyRecord.
+func (c *IdempotencyRecordClient) Update() *IdempotencyRecordUpdate {
+ mutation := newIdempotencyRecordMutation(c.config, OpUpdate)
+ return &IdempotencyRecordUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *IdempotencyRecordClient) UpdateOne(_m *IdempotencyRecord) *IdempotencyRecordUpdateOne {
+ mutation := newIdempotencyRecordMutation(c.config, OpUpdateOne, withIdempotencyRecord(_m))
+ return &IdempotencyRecordUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *IdempotencyRecordClient) UpdateOneID(id int64) *IdempotencyRecordUpdateOne {
+ mutation := newIdempotencyRecordMutation(c.config, OpUpdateOne, withIdempotencyRecordID(id))
+ return &IdempotencyRecordUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for IdempotencyRecord.
+func (c *IdempotencyRecordClient) Delete() *IdempotencyRecordDelete {
+ mutation := newIdempotencyRecordMutation(c.config, OpDelete)
+ return &IdempotencyRecordDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *IdempotencyRecordClient) DeleteOne(_m *IdempotencyRecord) *IdempotencyRecordDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *IdempotencyRecordClient) DeleteOneID(id int64) *IdempotencyRecordDeleteOne {
+ builder := c.Delete().Where(idempotencyrecord.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &IdempotencyRecordDeleteOne{builder}
+}
+
+// Query returns a query builder for IdempotencyRecord.
+func (c *IdempotencyRecordClient) Query() *IdempotencyRecordQuery {
+ return &IdempotencyRecordQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeIdempotencyRecord},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a IdempotencyRecord entity by its id.
+func (c *IdempotencyRecordClient) Get(ctx context.Context, id int64) (*IdempotencyRecord, error) {
+ return c.Query().Where(idempotencyrecord.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *IdempotencyRecordClient) GetX(ctx context.Context, id int64) *IdempotencyRecord {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// Hooks returns the client hooks.
+func (c *IdempotencyRecordClient) Hooks() []Hook {
+ return c.hooks.IdempotencyRecord
+}
+
+// Interceptors returns the client interceptors.
+func (c *IdempotencyRecordClient) Interceptors() []Interceptor {
+ return c.inters.IdempotencyRecord
+}
+
+func (c *IdempotencyRecordClient) mutate(ctx context.Context, m *IdempotencyRecordMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&IdempotencyRecordCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&IdempotencyRecordUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&IdempotencyRecordUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&IdempotencyRecordDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown IdempotencyRecord mutation op: %q", m.Op())
+ }
+}
+
// PromoCodeClient is a client for the PromoCode schema.
type PromoCodeClient struct {
config
@@ -3747,15 +3888,17 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
type (
hooks struct {
APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
- ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
- SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
- UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
+ ErrorPassthroughRule, Group, IdempotencyRecord, PromoCode, PromoCodeUsage,
+ Proxy, RedeemCode, SecuritySecret, Setting, UsageCleanupTask, UsageLog, User,
+ UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
+ UserSubscription []ent.Hook
}
inters struct {
APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
- ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode,
- SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
- UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
+ ErrorPassthroughRule, Group, IdempotencyRecord, PromoCode, PromoCodeUsage,
+ Proxy, RedeemCode, SecuritySecret, Setting, UsageCleanupTask, UsageLog, User,
+ UserAllowedGroup, UserAttributeDefinition, UserAttributeValue,
+ UserSubscription []ent.Interceptor
}
)
diff --git a/backend/ent/ent.go b/backend/ent/ent.go
index c4ec3387..5197e4d8 100644
--- a/backend/ent/ent.go
+++ b/backend/ent/ent.go
@@ -19,6 +19,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
+ "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
@@ -99,6 +100,7 @@ func checkColumn(t, c string) error {
announcementread.Table: announcementread.ValidColumn,
errorpassthroughrule.Table: errorpassthroughrule.ValidColumn,
group.Table: group.ValidColumn,
+ idempotencyrecord.Table: idempotencyrecord.ValidColumn,
promocode.Table: promocode.ValidColumn,
promocodeusage.Table: promocodeusage.ValidColumn,
proxy.Table: proxy.ValidColumn,
diff --git a/backend/ent/group.go b/backend/ent/group.go
index 79ec5bf5..76c3cae2 100644
--- a/backend/ent/group.go
+++ b/backend/ent/group.go
@@ -60,6 +60,8 @@ type Group struct {
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
// SoraVideoPricePerRequestHd holds the value of the "sora_video_price_per_request_hd" field.
SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"`
+ // SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
+ SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
// 是否仅允许 Claude Code 客户端
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
// 非 Claude Code 请求降级使用的分组 ID
@@ -188,7 +190,7 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
values[i] = new(sql.NullFloat64)
- case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
+ case group.FieldID, group.FieldDefaultValidityDays, group.FieldSoraStorageQuotaBytes, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
values[i] = new(sql.NullInt64)
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
values[i] = new(sql.NullString)
@@ -353,6 +355,12 @@ func (_m *Group) assignValues(columns []string, values []any) error {
_m.SoraVideoPricePerRequestHd = new(float64)
*_m.SoraVideoPricePerRequestHd = value.Float64
}
+ case group.FieldSoraStorageQuotaBytes:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i])
+ } else if value.Valid {
+ _m.SoraStorageQuotaBytes = value.Int64
+ }
case group.FieldClaudeCodeOnly:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field claude_code_only", values[i])
@@ -570,6 +578,9 @@ func (_m *Group) String() string {
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
+ builder.WriteString("sora_storage_quota_bytes=")
+ builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes))
+ builder.WriteString(", ")
builder.WriteString("claude_code_only=")
builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly))
builder.WriteString(", ")
diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go
index 133123a1..6ac4eea1 100644
--- a/backend/ent/group/group.go
+++ b/backend/ent/group/group.go
@@ -57,6 +57,8 @@ const (
FieldSoraVideoPricePerRequest = "sora_video_price_per_request"
// FieldSoraVideoPricePerRequestHd holds the string denoting the sora_video_price_per_request_hd field in the database.
FieldSoraVideoPricePerRequestHd = "sora_video_price_per_request_hd"
+ // FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database.
+ FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes"
// FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database.
FieldClaudeCodeOnly = "claude_code_only"
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
@@ -169,6 +171,7 @@ var Columns = []string{
FieldSoraImagePrice540,
FieldSoraVideoPricePerRequest,
FieldSoraVideoPricePerRequestHd,
+ FieldSoraStorageQuotaBytes,
FieldClaudeCodeOnly,
FieldFallbackGroupID,
FieldFallbackGroupIDOnInvalidRequest,
@@ -232,6 +235,8 @@ var (
SubscriptionTypeValidator func(string) error
// DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field.
DefaultDefaultValidityDays int
+ // DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field.
+ DefaultSoraStorageQuotaBytes int64
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
DefaultClaudeCodeOnly bool
// DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
@@ -357,6 +362,11 @@ func BySoraVideoPricePerRequestHd(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldSoraVideoPricePerRequestHd, opts...).ToFunc()
}
+// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field.
+func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc()
+}
+
// ByClaudeCodeOnly orders the results by the claude_code_only field.
func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc()
diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go
index 127d4ae9..4cf65d0f 100644
--- a/backend/ent/group/where.go
+++ b/backend/ent/group/where.go
@@ -160,6 +160,11 @@ func SoraVideoPricePerRequestHd(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v))
}
+// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ.
+func SoraStorageQuotaBytes(v int64) predicate.Group {
+ return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
+}
+
// ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ.
func ClaudeCodeOnly(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
@@ -1245,6 +1250,46 @@ func SoraVideoPricePerRequestHdNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequestHd))
}
+// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesEQ(v int64) predicate.Group {
+ return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesNEQ(v int64) predicate.Group {
+ return predicate.Group(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesIn(vs ...int64) predicate.Group {
+ return predicate.Group(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...))
+}
+
+// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.Group {
+ return predicate.Group(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...))
+}
+
+// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesGT(v int64) predicate.Group {
+ return predicate.Group(sql.FieldGT(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesGTE(v int64) predicate.Group {
+ return predicate.Group(sql.FieldGTE(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesLT(v int64) predicate.Group {
+ return predicate.Group(sql.FieldLT(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesLTE(v int64) predicate.Group {
+ return predicate.Group(sql.FieldLTE(FieldSoraStorageQuotaBytes, v))
+}
+
// ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field.
func ClaudeCodeOnlyEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go
index 4416516b..0ce5f959 100644
--- a/backend/ent/group_create.go
+++ b/backend/ent/group_create.go
@@ -314,6 +314,20 @@ func (_c *GroupCreate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupC
return _c
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (_c *GroupCreate) SetSoraStorageQuotaBytes(v int64) *GroupCreate {
+ _c.mutation.SetSoraStorageQuotaBytes(v)
+ return _c
+}
+
+// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
+func (_c *GroupCreate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupCreate {
+ if v != nil {
+ _c.SetSoraStorageQuotaBytes(*v)
+ }
+ return _c
+}
+
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate {
_c.mutation.SetClaudeCodeOnly(v)
@@ -575,6 +589,10 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultDefaultValidityDays
_c.mutation.SetDefaultValidityDays(v)
}
+ if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
+ v := group.DefaultSoraStorageQuotaBytes
+ _c.mutation.SetSoraStorageQuotaBytes(v)
+ }
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
v := group.DefaultClaudeCodeOnly
_c.mutation.SetClaudeCodeOnly(v)
@@ -647,6 +665,9 @@ func (_c *GroupCreate) check() error {
if _, ok := _c.mutation.DefaultValidityDays(); !ok {
return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)}
}
+ if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
+ return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "Group.sora_storage_quota_bytes"`)}
+ }
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
}
@@ -773,6 +794,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
_node.SoraVideoPricePerRequestHd = &value
}
+ if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok {
+ _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ _node.SoraStorageQuotaBytes = value
+ }
if value, ok := _c.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
_node.ClaudeCodeOnly = value
@@ -1345,6 +1370,24 @@ func (u *GroupUpsert) ClearSoraVideoPricePerRequestHd() *GroupUpsert {
return u
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (u *GroupUpsert) SetSoraStorageQuotaBytes(v int64) *GroupUpsert {
+ u.Set(group.FieldSoraStorageQuotaBytes, v)
+ return u
+}
+
+// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
+func (u *GroupUpsert) UpdateSoraStorageQuotaBytes() *GroupUpsert {
+ u.SetExcluded(group.FieldSoraStorageQuotaBytes)
+ return u
+}
+
+// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
+func (u *GroupUpsert) AddSoraStorageQuotaBytes(v int64) *GroupUpsert {
+ u.Add(group.FieldSoraStorageQuotaBytes, v)
+ return u
+}
+
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert {
u.Set(group.FieldClaudeCodeOnly, v)
@@ -1970,6 +2013,27 @@ func (u *GroupUpsertOne) ClearSoraVideoPricePerRequestHd() *GroupUpsertOne {
})
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (u *GroupUpsertOne) SetSoraStorageQuotaBytes(v int64) *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.SetSoraStorageQuotaBytes(v)
+ })
+}
+
+// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
+func (u *GroupUpsertOne) AddSoraStorageQuotaBytes(v int64) *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.AddSoraStorageQuotaBytes(v)
+ })
+}
+
+// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
+func (u *GroupUpsertOne) UpdateSoraStorageQuotaBytes() *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.UpdateSoraStorageQuotaBytes()
+ })
+}
+
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
@@ -2783,6 +2847,27 @@ func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequestHd() *GroupUpsertBulk {
})
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (u *GroupUpsertBulk) SetSoraStorageQuotaBytes(v int64) *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.SetSoraStorageQuotaBytes(v)
+ })
+}
+
+// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
+func (u *GroupUpsertBulk) AddSoraStorageQuotaBytes(v int64) *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.AddSoraStorageQuotaBytes(v)
+ })
+}
+
+// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
+func (u *GroupUpsertBulk) UpdateSoraStorageQuotaBytes() *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.UpdateSoraStorageQuotaBytes()
+ })
+}
+
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go
index db510e05..85575292 100644
--- a/backend/ent/group_update.go
+++ b/backend/ent/group_update.go
@@ -463,6 +463,27 @@ func (_u *GroupUpdate) ClearSoraVideoPricePerRequestHd() *GroupUpdate {
return _u
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (_u *GroupUpdate) SetSoraStorageQuotaBytes(v int64) *GroupUpdate {
+ _u.mutation.ResetSoraStorageQuotaBytes()
+ _u.mutation.SetSoraStorageQuotaBytes(v)
+ return _u
+}
+
+// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
+func (_u *GroupUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdate {
+ if v != nil {
+ _u.SetSoraStorageQuotaBytes(*v)
+ }
+ return _u
+}
+
+// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
+func (_u *GroupUpdate) AddSoraStorageQuotaBytes(v int64) *GroupUpdate {
+ _u.mutation.AddSoraStorageQuotaBytes(v)
+ return _u
+}
+
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate {
_u.mutation.SetClaudeCodeOnly(v)
@@ -1036,6 +1057,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.SoraVideoPricePerRequestHdCleared() {
_spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64)
}
+ if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
+ _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
+ _spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ }
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
}
@@ -1825,6 +1852,27 @@ func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequestHd() *GroupUpdateOne {
return _u
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (_u *GroupUpdateOne) SetSoraStorageQuotaBytes(v int64) *GroupUpdateOne {
+ _u.mutation.ResetSoraStorageQuotaBytes()
+ _u.mutation.SetSoraStorageQuotaBytes(v)
+ return _u
+}
+
+// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
+func (_u *GroupUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdateOne {
+ if v != nil {
+ _u.SetSoraStorageQuotaBytes(*v)
+ }
+ return _u
+}
+
+// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
+func (_u *GroupUpdateOne) AddSoraStorageQuotaBytes(v int64) *GroupUpdateOne {
+ _u.mutation.AddSoraStorageQuotaBytes(v)
+ return _u
+}
+
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne {
_u.mutation.SetClaudeCodeOnly(v)
@@ -2428,6 +2476,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if _u.mutation.SoraVideoPricePerRequestHdCleared() {
_spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64)
}
+ if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
+ _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
+ _spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ }
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
}
diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go
index aff9caa0..49d7f3c5 100644
--- a/backend/ent/hook/hook.go
+++ b/backend/ent/hook/hook.go
@@ -93,6 +93,18 @@ func (f GroupFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.GroupMutation", m)
}
+// The IdempotencyRecordFunc type is an adapter to allow the use of ordinary
+// function as IdempotencyRecord mutator.
+type IdempotencyRecordFunc func(context.Context, *ent.IdempotencyRecordMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f IdempotencyRecordFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.IdempotencyRecordMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdempotencyRecordMutation", m)
+}
+
// The PromoCodeFunc type is an adapter to allow the use of ordinary
// function as PromoCode mutator.
type PromoCodeFunc func(context.Context, *ent.PromoCodeMutation) (ent.Value, error)
diff --git a/backend/ent/idempotencyrecord.go b/backend/ent/idempotencyrecord.go
new file mode 100644
index 00000000..ab120f8f
--- /dev/null
+++ b/backend/ent/idempotencyrecord.go
@@ -0,0 +1,228 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+)
+
+// IdempotencyRecord is the model entity for the IdempotencyRecord schema.
+type IdempotencyRecord struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // Scope holds the value of the "scope" field.
+ Scope string `json:"scope,omitempty"`
+ // IdempotencyKeyHash holds the value of the "idempotency_key_hash" field.
+ IdempotencyKeyHash string `json:"idempotency_key_hash,omitempty"`
+ // RequestFingerprint holds the value of the "request_fingerprint" field.
+ RequestFingerprint string `json:"request_fingerprint,omitempty"`
+ // Status holds the value of the "status" field.
+ Status string `json:"status,omitempty"`
+ // ResponseStatus holds the value of the "response_status" field.
+ ResponseStatus *int `json:"response_status,omitempty"`
+ // ResponseBody holds the value of the "response_body" field.
+ ResponseBody *string `json:"response_body,omitempty"`
+ // ErrorReason holds the value of the "error_reason" field.
+ ErrorReason *string `json:"error_reason,omitempty"`
+ // LockedUntil holds the value of the "locked_until" field.
+ LockedUntil *time.Time `json:"locked_until,omitempty"`
+ // ExpiresAt holds the value of the "expires_at" field.
+ ExpiresAt time.Time `json:"expires_at,omitempty"`
+ selectValues sql.SelectValues
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*IdempotencyRecord) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case idempotencyrecord.FieldID, idempotencyrecord.FieldResponseStatus:
+ values[i] = new(sql.NullInt64)
+ case idempotencyrecord.FieldScope, idempotencyrecord.FieldIdempotencyKeyHash, idempotencyrecord.FieldRequestFingerprint, idempotencyrecord.FieldStatus, idempotencyrecord.FieldResponseBody, idempotencyrecord.FieldErrorReason:
+ values[i] = new(sql.NullString)
+ case idempotencyrecord.FieldCreatedAt, idempotencyrecord.FieldUpdatedAt, idempotencyrecord.FieldLockedUntil, idempotencyrecord.FieldExpiresAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the IdempotencyRecord fields.
+func (_m *IdempotencyRecord) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case idempotencyrecord.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case idempotencyrecord.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case idempotencyrecord.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case idempotencyrecord.FieldScope:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field scope", values[i])
+ } else if value.Valid {
+ _m.Scope = value.String
+ }
+ case idempotencyrecord.FieldIdempotencyKeyHash:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field idempotency_key_hash", values[i])
+ } else if value.Valid {
+ _m.IdempotencyKeyHash = value.String
+ }
+ case idempotencyrecord.FieldRequestFingerprint:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field request_fingerprint", values[i])
+ } else if value.Valid {
+ _m.RequestFingerprint = value.String
+ }
+ case idempotencyrecord.FieldStatus:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field status", values[i])
+ } else if value.Valid {
+ _m.Status = value.String
+ }
+ case idempotencyrecord.FieldResponseStatus:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field response_status", values[i])
+ } else if value.Valid {
+ _m.ResponseStatus = new(int)
+ *_m.ResponseStatus = int(value.Int64)
+ }
+ case idempotencyrecord.FieldResponseBody:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field response_body", values[i])
+ } else if value.Valid {
+ _m.ResponseBody = new(string)
+ *_m.ResponseBody = value.String
+ }
+ case idempotencyrecord.FieldErrorReason:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field error_reason", values[i])
+ } else if value.Valid {
+ _m.ErrorReason = new(string)
+ *_m.ErrorReason = value.String
+ }
+ case idempotencyrecord.FieldLockedUntil:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field locked_until", values[i])
+ } else if value.Valid {
+ _m.LockedUntil = new(time.Time)
+ *_m.LockedUntil = value.Time
+ }
+ case idempotencyrecord.FieldExpiresAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field expires_at", values[i])
+ } else if value.Valid {
+ _m.ExpiresAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the IdempotencyRecord.
+// This includes values selected through modifiers, order, etc.
+func (_m *IdempotencyRecord) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// Update returns a builder for updating this IdempotencyRecord.
+// Note that you need to call IdempotencyRecord.Unwrap() before calling this method if this IdempotencyRecord
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *IdempotencyRecord) Update() *IdempotencyRecordUpdateOne {
+ return NewIdempotencyRecordClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the IdempotencyRecord entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *IdempotencyRecord) Unwrap() *IdempotencyRecord {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: IdempotencyRecord is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *IdempotencyRecord) String() string {
+ var builder strings.Builder
+ builder.WriteString("IdempotencyRecord(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("scope=")
+ builder.WriteString(_m.Scope)
+ builder.WriteString(", ")
+ builder.WriteString("idempotency_key_hash=")
+ builder.WriteString(_m.IdempotencyKeyHash)
+ builder.WriteString(", ")
+ builder.WriteString("request_fingerprint=")
+ builder.WriteString(_m.RequestFingerprint)
+ builder.WriteString(", ")
+ builder.WriteString("status=")
+ builder.WriteString(_m.Status)
+ builder.WriteString(", ")
+ if v := _m.ResponseStatus; v != nil {
+ builder.WriteString("response_status=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ if v := _m.ResponseBody; v != nil {
+ builder.WriteString("response_body=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ if v := _m.ErrorReason; v != nil {
+ builder.WriteString("error_reason=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ if v := _m.LockedUntil; v != nil {
+ builder.WriteString("locked_until=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("expires_at=")
+ builder.WriteString(_m.ExpiresAt.Format(time.ANSIC))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// IdempotencyRecords is a parsable slice of IdempotencyRecord.
+type IdempotencyRecords []*IdempotencyRecord
diff --git a/backend/ent/idempotencyrecord/idempotencyrecord.go b/backend/ent/idempotencyrecord/idempotencyrecord.go
new file mode 100644
index 00000000..d9686f60
--- /dev/null
+++ b/backend/ent/idempotencyrecord/idempotencyrecord.go
@@ -0,0 +1,148 @@
+// Code generated by ent, DO NOT EDIT.
+
+package idempotencyrecord
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+)
+
+const (
+ // Label holds the string label denoting the idempotencyrecord type in the database.
+ Label = "idempotency_record"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldScope holds the string denoting the scope field in the database.
+ FieldScope = "scope"
+ // FieldIdempotencyKeyHash holds the string denoting the idempotency_key_hash field in the database.
+ FieldIdempotencyKeyHash = "idempotency_key_hash"
+ // FieldRequestFingerprint holds the string denoting the request_fingerprint field in the database.
+ FieldRequestFingerprint = "request_fingerprint"
+ // FieldStatus holds the string denoting the status field in the database.
+ FieldStatus = "status"
+ // FieldResponseStatus holds the string denoting the response_status field in the database.
+ FieldResponseStatus = "response_status"
+ // FieldResponseBody holds the string denoting the response_body field in the database.
+ FieldResponseBody = "response_body"
+ // FieldErrorReason holds the string denoting the error_reason field in the database.
+ FieldErrorReason = "error_reason"
+ // FieldLockedUntil holds the string denoting the locked_until field in the database.
+ FieldLockedUntil = "locked_until"
+ // FieldExpiresAt holds the string denoting the expires_at field in the database.
+ FieldExpiresAt = "expires_at"
+ // Table holds the table name of the idempotencyrecord in the database.
+ Table = "idempotency_records"
+)
+
+// Columns holds all SQL columns for idempotencyrecord fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldScope,
+ FieldIdempotencyKeyHash,
+ FieldRequestFingerprint,
+ FieldStatus,
+ FieldResponseStatus,
+ FieldResponseBody,
+ FieldErrorReason,
+ FieldLockedUntil,
+ FieldExpiresAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // ScopeValidator is a validator for the "scope" field. It is called by the builders before save.
+ ScopeValidator func(string) error
+ // IdempotencyKeyHashValidator is a validator for the "idempotency_key_hash" field. It is called by the builders before save.
+ IdempotencyKeyHashValidator func(string) error
+ // RequestFingerprintValidator is a validator for the "request_fingerprint" field. It is called by the builders before save.
+ RequestFingerprintValidator func(string) error
+ // StatusValidator is a validator for the "status" field. It is called by the builders before save.
+ StatusValidator func(string) error
+ // ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save.
+ ErrorReasonValidator func(string) error
+)
+
+// OrderOption defines the ordering options for the IdempotencyRecord queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByScope orders the results by the scope field.
+func ByScope(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldScope, opts...).ToFunc()
+}
+
+// ByIdempotencyKeyHash orders the results by the idempotency_key_hash field.
+func ByIdempotencyKeyHash(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIdempotencyKeyHash, opts...).ToFunc()
+}
+
+// ByRequestFingerprint orders the results by the request_fingerprint field.
+func ByRequestFingerprint(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRequestFingerprint, opts...).ToFunc()
+}
+
+// ByStatus orders the results by the status field.
+func ByStatus(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldStatus, opts...).ToFunc()
+}
+
+// ByResponseStatus orders the results by the response_status field.
+func ByResponseStatus(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldResponseStatus, opts...).ToFunc()
+}
+
+// ByResponseBody orders the results by the response_body field.
+func ByResponseBody(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldResponseBody, opts...).ToFunc()
+}
+
+// ByErrorReason orders the results by the error_reason field.
+func ByErrorReason(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldErrorReason, opts...).ToFunc()
+}
+
+// ByLockedUntil orders the results by the locked_until field.
+func ByLockedUntil(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldLockedUntil, opts...).ToFunc()
+}
+
+// ByExpiresAt orders the results by the expires_at field.
+func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
+}
diff --git a/backend/ent/idempotencyrecord/where.go b/backend/ent/idempotencyrecord/where.go
new file mode 100644
index 00000000..c3d8d9d5
--- /dev/null
+++ b/backend/ent/idempotencyrecord/where.go
@@ -0,0 +1,755 @@
+// Code generated by ent, DO NOT EDIT.
+
+package idempotencyrecord
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// Scope applies equality check predicate on the "scope" field. It's identical to ScopeEQ.
+func Scope(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldScope, v))
+}
+
+// IdempotencyKeyHash applies equality check predicate on the "idempotency_key_hash" field. It's identical to IdempotencyKeyHashEQ.
+func IdempotencyKeyHash(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldIdempotencyKeyHash, v))
+}
+
+// RequestFingerprint applies equality check predicate on the "request_fingerprint" field. It's identical to RequestFingerprintEQ.
+func RequestFingerprint(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldRequestFingerprint, v))
+}
+
+// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
+func Status(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldStatus, v))
+}
+
+// ResponseStatus applies equality check predicate on the "response_status" field. It's identical to ResponseStatusEQ.
+func ResponseStatus(v int) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseStatus, v))
+}
+
+// ResponseBody applies equality check predicate on the "response_body" field. It's identical to ResponseBodyEQ.
+func ResponseBody(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseBody, v))
+}
+
+// ErrorReason applies equality check predicate on the "error_reason" field. It's identical to ErrorReasonEQ.
+func ErrorReason(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldErrorReason, v))
+}
+
+// LockedUntil applies equality check predicate on the "locked_until" field. It's identical to LockedUntilEQ.
+func LockedUntil(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldLockedUntil, v))
+}
+
+// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ.
+func ExpiresAt(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldExpiresAt, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// ScopeEQ applies the EQ predicate on the "scope" field.
+func ScopeEQ(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldScope, v))
+}
+
+// ScopeNEQ applies the NEQ predicate on the "scope" field.
+func ScopeNEQ(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNEQ(FieldScope, v))
+}
+
+// ScopeIn applies the In predicate on the "scope" field.
+func ScopeIn(vs ...string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldIn(FieldScope, vs...))
+}
+
+// ScopeNotIn applies the NotIn predicate on the "scope" field.
+func ScopeNotIn(vs ...string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNotIn(FieldScope, vs...))
+}
+
+// ScopeGT applies the GT predicate on the "scope" field.
+func ScopeGT(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGT(FieldScope, v))
+}
+
+// ScopeGTE applies the GTE predicate on the "scope" field.
+func ScopeGTE(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGTE(FieldScope, v))
+}
+
+// ScopeLT applies the LT predicate on the "scope" field.
+func ScopeLT(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLT(FieldScope, v))
+}
+
+// ScopeLTE applies the LTE predicate on the "scope" field.
+func ScopeLTE(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLTE(FieldScope, v))
+}
+
+// ScopeContains applies the Contains predicate on the "scope" field.
+func ScopeContains(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldContains(FieldScope, v))
+}
+
+// ScopeHasPrefix applies the HasPrefix predicate on the "scope" field.
+func ScopeHasPrefix(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldScope, v))
+}
+
+// ScopeHasSuffix applies the HasSuffix predicate on the "scope" field.
+func ScopeHasSuffix(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldScope, v))
+}
+
+// ScopeEqualFold applies the EqualFold predicate on the "scope" field.
+func ScopeEqualFold(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldScope, v))
+}
+
+// ScopeContainsFold applies the ContainsFold predicate on the "scope" field.
+func ScopeContainsFold(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldScope, v))
+}
+
+// IdempotencyKeyHashEQ applies the EQ predicate on the "idempotency_key_hash" field.
+func IdempotencyKeyHashEQ(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldIdempotencyKeyHash, v))
+}
+
+// IdempotencyKeyHashNEQ applies the NEQ predicate on the "idempotency_key_hash" field.
+func IdempotencyKeyHashNEQ(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNEQ(FieldIdempotencyKeyHash, v))
+}
+
+// IdempotencyKeyHashIn applies the In predicate on the "idempotency_key_hash" field.
+func IdempotencyKeyHashIn(vs ...string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldIn(FieldIdempotencyKeyHash, vs...))
+}
+
+// IdempotencyKeyHashNotIn applies the NotIn predicate on the "idempotency_key_hash" field.
+func IdempotencyKeyHashNotIn(vs ...string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNotIn(FieldIdempotencyKeyHash, vs...))
+}
+
+// IdempotencyKeyHashGT applies the GT predicate on the "idempotency_key_hash" field.
+func IdempotencyKeyHashGT(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGT(FieldIdempotencyKeyHash, v))
+}
+
+// IdempotencyKeyHashGTE applies the GTE predicate on the "idempotency_key_hash" field.
+func IdempotencyKeyHashGTE(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGTE(FieldIdempotencyKeyHash, v))
+}
+
+// IdempotencyKeyHashLT applies the LT predicate on the "idempotency_key_hash" field.
+func IdempotencyKeyHashLT(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLT(FieldIdempotencyKeyHash, v))
+}
+
+// IdempotencyKeyHashLTE applies the LTE predicate on the "idempotency_key_hash" field.
+func IdempotencyKeyHashLTE(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLTE(FieldIdempotencyKeyHash, v))
+}
+
+// IdempotencyKeyHashContains applies the Contains predicate on the "idempotency_key_hash" field.
+func IdempotencyKeyHashContains(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldContains(FieldIdempotencyKeyHash, v))
+}
+
+// IdempotencyKeyHashHasPrefix applies the HasPrefix predicate on the "idempotency_key_hash" field.
+func IdempotencyKeyHashHasPrefix(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldIdempotencyKeyHash, v))
+}
+
+// IdempotencyKeyHashHasSuffix applies the HasSuffix predicate on the "idempotency_key_hash" field.
+func IdempotencyKeyHashHasSuffix(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldIdempotencyKeyHash, v))
+}
+
+// IdempotencyKeyHashEqualFold applies the EqualFold predicate on the "idempotency_key_hash" field.
+func IdempotencyKeyHashEqualFold(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldIdempotencyKeyHash, v))
+}
+
+// IdempotencyKeyHashContainsFold applies the ContainsFold predicate on the "idempotency_key_hash" field.
+func IdempotencyKeyHashContainsFold(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldIdempotencyKeyHash, v))
+}
+
+// RequestFingerprintEQ applies the EQ predicate on the "request_fingerprint" field.
+func RequestFingerprintEQ(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldRequestFingerprint, v))
+}
+
+// RequestFingerprintNEQ applies the NEQ predicate on the "request_fingerprint" field.
+func RequestFingerprintNEQ(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNEQ(FieldRequestFingerprint, v))
+}
+
+// RequestFingerprintIn applies the In predicate on the "request_fingerprint" field.
+func RequestFingerprintIn(vs ...string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldIn(FieldRequestFingerprint, vs...))
+}
+
+// RequestFingerprintNotIn applies the NotIn predicate on the "request_fingerprint" field.
+func RequestFingerprintNotIn(vs ...string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNotIn(FieldRequestFingerprint, vs...))
+}
+
+// RequestFingerprintGT applies the GT predicate on the "request_fingerprint" field.
+func RequestFingerprintGT(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGT(FieldRequestFingerprint, v))
+}
+
+// RequestFingerprintGTE applies the GTE predicate on the "request_fingerprint" field.
+func RequestFingerprintGTE(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGTE(FieldRequestFingerprint, v))
+}
+
+// RequestFingerprintLT applies the LT predicate on the "request_fingerprint" field.
+func RequestFingerprintLT(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLT(FieldRequestFingerprint, v))
+}
+
+// RequestFingerprintLTE applies the LTE predicate on the "request_fingerprint" field.
+func RequestFingerprintLTE(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLTE(FieldRequestFingerprint, v))
+}
+
+// RequestFingerprintContains applies the Contains predicate on the "request_fingerprint" field.
+func RequestFingerprintContains(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldContains(FieldRequestFingerprint, v))
+}
+
+// RequestFingerprintHasPrefix applies the HasPrefix predicate on the "request_fingerprint" field.
+func RequestFingerprintHasPrefix(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldRequestFingerprint, v))
+}
+
+// RequestFingerprintHasSuffix applies the HasSuffix predicate on the "request_fingerprint" field.
+func RequestFingerprintHasSuffix(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldRequestFingerprint, v))
+}
+
+// RequestFingerprintEqualFold applies the EqualFold predicate on the "request_fingerprint" field.
+func RequestFingerprintEqualFold(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldRequestFingerprint, v))
+}
+
+// RequestFingerprintContainsFold applies the ContainsFold predicate on the "request_fingerprint" field.
+func RequestFingerprintContainsFold(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldRequestFingerprint, v))
+}
+
+// StatusEQ applies the EQ predicate on the "status" field.
+func StatusEQ(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldStatus, v))
+}
+
+// StatusNEQ applies the NEQ predicate on the "status" field.
+func StatusNEQ(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNEQ(FieldStatus, v))
+}
+
+// StatusIn applies the In predicate on the "status" field.
+func StatusIn(vs ...string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldIn(FieldStatus, vs...))
+}
+
+// StatusNotIn applies the NotIn predicate on the "status" field.
+func StatusNotIn(vs ...string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNotIn(FieldStatus, vs...))
+}
+
+// StatusGT applies the GT predicate on the "status" field.
+func StatusGT(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGT(FieldStatus, v))
+}
+
+// StatusGTE applies the GTE predicate on the "status" field.
+func StatusGTE(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGTE(FieldStatus, v))
+}
+
+// StatusLT applies the LT predicate on the "status" field.
+func StatusLT(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLT(FieldStatus, v))
+}
+
+// StatusLTE applies the LTE predicate on the "status" field.
+func StatusLTE(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLTE(FieldStatus, v))
+}
+
+// StatusContains applies the Contains predicate on the "status" field.
+func StatusContains(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldContains(FieldStatus, v))
+}
+
+// StatusHasPrefix applies the HasPrefix predicate on the "status" field.
+func StatusHasPrefix(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldStatus, v))
+}
+
+// StatusHasSuffix applies the HasSuffix predicate on the "status" field.
+func StatusHasSuffix(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldStatus, v))
+}
+
+// StatusEqualFold applies the EqualFold predicate on the "status" field.
+func StatusEqualFold(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldStatus, v))
+}
+
+// StatusContainsFold applies the ContainsFold predicate on the "status" field.
+func StatusContainsFold(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldStatus, v))
+}
+
+// ResponseStatusEQ applies the EQ predicate on the "response_status" field.
+func ResponseStatusEQ(v int) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseStatus, v))
+}
+
+// ResponseStatusNEQ applies the NEQ predicate on the "response_status" field.
+func ResponseStatusNEQ(v int) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNEQ(FieldResponseStatus, v))
+}
+
+// ResponseStatusIn applies the In predicate on the "response_status" field.
+func ResponseStatusIn(vs ...int) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldIn(FieldResponseStatus, vs...))
+}
+
+// ResponseStatusNotIn applies the NotIn predicate on the "response_status" field.
+func ResponseStatusNotIn(vs ...int) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNotIn(FieldResponseStatus, vs...))
+}
+
+// ResponseStatusGT applies the GT predicate on the "response_status" field.
+func ResponseStatusGT(v int) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGT(FieldResponseStatus, v))
+}
+
+// ResponseStatusGTE applies the GTE predicate on the "response_status" field.
+func ResponseStatusGTE(v int) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGTE(FieldResponseStatus, v))
+}
+
+// ResponseStatusLT applies the LT predicate on the "response_status" field.
+func ResponseStatusLT(v int) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLT(FieldResponseStatus, v))
+}
+
+// ResponseStatusLTE applies the LTE predicate on the "response_status" field.
+func ResponseStatusLTE(v int) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLTE(FieldResponseStatus, v))
+}
+
+// ResponseStatusIsNil applies the IsNil predicate on the "response_status" field.
+func ResponseStatusIsNil() predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldIsNull(FieldResponseStatus))
+}
+
+// ResponseStatusNotNil applies the NotNil predicate on the "response_status" field.
+func ResponseStatusNotNil() predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNotNull(FieldResponseStatus))
+}
+
+// ResponseBodyEQ applies the EQ predicate on the "response_body" field.
+func ResponseBodyEQ(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseBody, v))
+}
+
+// ResponseBodyNEQ applies the NEQ predicate on the "response_body" field.
+func ResponseBodyNEQ(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNEQ(FieldResponseBody, v))
+}
+
+// ResponseBodyIn applies the In predicate on the "response_body" field.
+func ResponseBodyIn(vs ...string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldIn(FieldResponseBody, vs...))
+}
+
+// ResponseBodyNotIn applies the NotIn predicate on the "response_body" field.
+func ResponseBodyNotIn(vs ...string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNotIn(FieldResponseBody, vs...))
+}
+
+// ResponseBodyGT applies the GT predicate on the "response_body" field.
+func ResponseBodyGT(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGT(FieldResponseBody, v))
+}
+
+// ResponseBodyGTE applies the GTE predicate on the "response_body" field.
+func ResponseBodyGTE(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGTE(FieldResponseBody, v))
+}
+
+// ResponseBodyLT applies the LT predicate on the "response_body" field.
+func ResponseBodyLT(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLT(FieldResponseBody, v))
+}
+
+// ResponseBodyLTE applies the LTE predicate on the "response_body" field.
+func ResponseBodyLTE(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLTE(FieldResponseBody, v))
+}
+
+// ResponseBodyContains applies the Contains predicate on the "response_body" field.
+func ResponseBodyContains(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldContains(FieldResponseBody, v))
+}
+
+// ResponseBodyHasPrefix applies the HasPrefix predicate on the "response_body" field.
+func ResponseBodyHasPrefix(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldResponseBody, v))
+}
+
+// ResponseBodyHasSuffix applies the HasSuffix predicate on the "response_body" field.
+func ResponseBodyHasSuffix(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldResponseBody, v))
+}
+
+// ResponseBodyIsNil applies the IsNil predicate on the "response_body" field.
+func ResponseBodyIsNil() predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldIsNull(FieldResponseBody))
+}
+
+// ResponseBodyNotNil applies the NotNil predicate on the "response_body" field.
+func ResponseBodyNotNil() predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNotNull(FieldResponseBody))
+}
+
+// ResponseBodyEqualFold applies the EqualFold predicate on the "response_body" field.
+func ResponseBodyEqualFold(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldResponseBody, v))
+}
+
+// ResponseBodyContainsFold applies the ContainsFold predicate on the "response_body" field.
+func ResponseBodyContainsFold(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldResponseBody, v))
+}
+
+// ErrorReasonEQ applies the EQ predicate on the "error_reason" field.
+func ErrorReasonEQ(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldErrorReason, v))
+}
+
+// ErrorReasonNEQ applies the NEQ predicate on the "error_reason" field.
+func ErrorReasonNEQ(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNEQ(FieldErrorReason, v))
+}
+
+// ErrorReasonIn applies the In predicate on the "error_reason" field.
+func ErrorReasonIn(vs ...string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldIn(FieldErrorReason, vs...))
+}
+
+// ErrorReasonNotIn applies the NotIn predicate on the "error_reason" field.
+func ErrorReasonNotIn(vs ...string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNotIn(FieldErrorReason, vs...))
+}
+
+// ErrorReasonGT applies the GT predicate on the "error_reason" field.
+func ErrorReasonGT(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGT(FieldErrorReason, v))
+}
+
+// ErrorReasonGTE applies the GTE predicate on the "error_reason" field.
+func ErrorReasonGTE(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGTE(FieldErrorReason, v))
+}
+
+// ErrorReasonLT applies the LT predicate on the "error_reason" field.
+func ErrorReasonLT(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLT(FieldErrorReason, v))
+}
+
+// ErrorReasonLTE applies the LTE predicate on the "error_reason" field.
+func ErrorReasonLTE(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLTE(FieldErrorReason, v))
+}
+
+// ErrorReasonContains applies the Contains predicate on the "error_reason" field.
+func ErrorReasonContains(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldContains(FieldErrorReason, v))
+}
+
+// ErrorReasonHasPrefix applies the HasPrefix predicate on the "error_reason" field.
+func ErrorReasonHasPrefix(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldErrorReason, v))
+}
+
+// ErrorReasonHasSuffix applies the HasSuffix predicate on the "error_reason" field.
+func ErrorReasonHasSuffix(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldErrorReason, v))
+}
+
+// ErrorReasonIsNil applies the IsNil predicate on the "error_reason" field.
+func ErrorReasonIsNil() predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldIsNull(FieldErrorReason))
+}
+
+// ErrorReasonNotNil applies the NotNil predicate on the "error_reason" field.
+func ErrorReasonNotNil() predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNotNull(FieldErrorReason))
+}
+
+// ErrorReasonEqualFold applies the EqualFold predicate on the "error_reason" field.
+func ErrorReasonEqualFold(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldErrorReason, v))
+}
+
+// ErrorReasonContainsFold applies the ContainsFold predicate on the "error_reason" field.
+func ErrorReasonContainsFold(v string) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldErrorReason, v))
+}
+
+// LockedUntilEQ applies the EQ predicate on the "locked_until" field.
+func LockedUntilEQ(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldLockedUntil, v))
+}
+
+// LockedUntilNEQ applies the NEQ predicate on the "locked_until" field.
+func LockedUntilNEQ(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNEQ(FieldLockedUntil, v))
+}
+
+// LockedUntilIn applies the In predicate on the "locked_until" field.
+func LockedUntilIn(vs ...time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldIn(FieldLockedUntil, vs...))
+}
+
+// LockedUntilNotIn applies the NotIn predicate on the "locked_until" field.
+func LockedUntilNotIn(vs ...time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNotIn(FieldLockedUntil, vs...))
+}
+
+// LockedUntilGT applies the GT predicate on the "locked_until" field.
+func LockedUntilGT(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGT(FieldLockedUntil, v))
+}
+
+// LockedUntilGTE applies the GTE predicate on the "locked_until" field.
+func LockedUntilGTE(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGTE(FieldLockedUntil, v))
+}
+
+// LockedUntilLT applies the LT predicate on the "locked_until" field.
+func LockedUntilLT(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLT(FieldLockedUntil, v))
+}
+
+// LockedUntilLTE applies the LTE predicate on the "locked_until" field.
+func LockedUntilLTE(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLTE(FieldLockedUntil, v))
+}
+
+// LockedUntilIsNil applies the IsNil predicate on the "locked_until" field.
+func LockedUntilIsNil() predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldIsNull(FieldLockedUntil))
+}
+
+// LockedUntilNotNil applies the NotNil predicate on the "locked_until" field.
+func LockedUntilNotNil() predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNotNull(FieldLockedUntil))
+}
+
+// ExpiresAtEQ applies the EQ predicate on the "expires_at" field.
+func ExpiresAtEQ(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldEQ(FieldExpiresAt, v))
+}
+
+// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field.
+func ExpiresAtNEQ(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNEQ(FieldExpiresAt, v))
+}
+
+// ExpiresAtIn applies the In predicate on the "expires_at" field.
+func ExpiresAtIn(vs ...time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldIn(FieldExpiresAt, vs...))
+}
+
+// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field.
+func ExpiresAtNotIn(vs ...time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldNotIn(FieldExpiresAt, vs...))
+}
+
+// ExpiresAtGT applies the GT predicate on the "expires_at" field.
+func ExpiresAtGT(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGT(FieldExpiresAt, v))
+}
+
+// ExpiresAtGTE applies the GTE predicate on the "expires_at" field.
+func ExpiresAtGTE(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldGTE(FieldExpiresAt, v))
+}
+
+// ExpiresAtLT applies the LT predicate on the "expires_at" field.
+func ExpiresAtLT(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLT(FieldExpiresAt, v))
+}
+
+// ExpiresAtLTE applies the LTE predicate on the "expires_at" field.
+func ExpiresAtLTE(v time.Time) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.FieldLTE(FieldExpiresAt, v))
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.IdempotencyRecord) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.IdempotencyRecord) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.IdempotencyRecord) predicate.IdempotencyRecord {
+ return predicate.IdempotencyRecord(sql.NotPredicates(p))
+}
diff --git a/backend/ent/idempotencyrecord_create.go b/backend/ent/idempotencyrecord_create.go
new file mode 100644
index 00000000..bf4deaf2
--- /dev/null
+++ b/backend/ent/idempotencyrecord_create.go
@@ -0,0 +1,1132 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+)
+
+// IdempotencyRecordCreate is the builder for creating a IdempotencyRecord entity.
+type IdempotencyRecordCreate struct {
+ config
+ mutation *IdempotencyRecordMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *IdempotencyRecordCreate) SetCreatedAt(v time.Time) *IdempotencyRecordCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *IdempotencyRecordCreate) SetNillableCreatedAt(v *time.Time) *IdempotencyRecordCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *IdempotencyRecordCreate) SetUpdatedAt(v time.Time) *IdempotencyRecordCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *IdempotencyRecordCreate) SetNillableUpdatedAt(v *time.Time) *IdempotencyRecordCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetScope sets the "scope" field.
+func (_c *IdempotencyRecordCreate) SetScope(v string) *IdempotencyRecordCreate {
+ _c.mutation.SetScope(v)
+ return _c
+}
+
+// SetIdempotencyKeyHash sets the "idempotency_key_hash" field.
+func (_c *IdempotencyRecordCreate) SetIdempotencyKeyHash(v string) *IdempotencyRecordCreate {
+ _c.mutation.SetIdempotencyKeyHash(v)
+ return _c
+}
+
+// SetRequestFingerprint sets the "request_fingerprint" field.
+func (_c *IdempotencyRecordCreate) SetRequestFingerprint(v string) *IdempotencyRecordCreate {
+ _c.mutation.SetRequestFingerprint(v)
+ return _c
+}
+
+// SetStatus sets the "status" field.
+func (_c *IdempotencyRecordCreate) SetStatus(v string) *IdempotencyRecordCreate {
+ _c.mutation.SetStatus(v)
+ return _c
+}
+
+// SetResponseStatus sets the "response_status" field.
+func (_c *IdempotencyRecordCreate) SetResponseStatus(v int) *IdempotencyRecordCreate {
+ _c.mutation.SetResponseStatus(v)
+ return _c
+}
+
+// SetNillableResponseStatus sets the "response_status" field if the given value is not nil.
+func (_c *IdempotencyRecordCreate) SetNillableResponseStatus(v *int) *IdempotencyRecordCreate {
+ if v != nil {
+ _c.SetResponseStatus(*v)
+ }
+ return _c
+}
+
+// SetResponseBody sets the "response_body" field.
+func (_c *IdempotencyRecordCreate) SetResponseBody(v string) *IdempotencyRecordCreate {
+ _c.mutation.SetResponseBody(v)
+ return _c
+}
+
+// SetNillableResponseBody sets the "response_body" field if the given value is not nil.
+func (_c *IdempotencyRecordCreate) SetNillableResponseBody(v *string) *IdempotencyRecordCreate {
+ if v != nil {
+ _c.SetResponseBody(*v)
+ }
+ return _c
+}
+
+// SetErrorReason sets the "error_reason" field.
+func (_c *IdempotencyRecordCreate) SetErrorReason(v string) *IdempotencyRecordCreate {
+ _c.mutation.SetErrorReason(v)
+ return _c
+}
+
+// SetNillableErrorReason sets the "error_reason" field if the given value is not nil.
+func (_c *IdempotencyRecordCreate) SetNillableErrorReason(v *string) *IdempotencyRecordCreate {
+ if v != nil {
+ _c.SetErrorReason(*v)
+ }
+ return _c
+}
+
+// SetLockedUntil sets the "locked_until" field.
+func (_c *IdempotencyRecordCreate) SetLockedUntil(v time.Time) *IdempotencyRecordCreate {
+ _c.mutation.SetLockedUntil(v)
+ return _c
+}
+
+// SetNillableLockedUntil sets the "locked_until" field if the given value is not nil.
+func (_c *IdempotencyRecordCreate) SetNillableLockedUntil(v *time.Time) *IdempotencyRecordCreate {
+ if v != nil {
+ _c.SetLockedUntil(*v)
+ }
+ return _c
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (_c *IdempotencyRecordCreate) SetExpiresAt(v time.Time) *IdempotencyRecordCreate {
+ _c.mutation.SetExpiresAt(v)
+ return _c
+}
+
+// Mutation returns the IdempotencyRecordMutation object of the builder.
+func (_c *IdempotencyRecordCreate) Mutation() *IdempotencyRecordMutation {
+ return _c.mutation
+}
+
+// Save creates the IdempotencyRecord in the database.
+func (_c *IdempotencyRecordCreate) Save(ctx context.Context) (*IdempotencyRecord, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *IdempotencyRecordCreate) SaveX(ctx context.Context) *IdempotencyRecord {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *IdempotencyRecordCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *IdempotencyRecordCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *IdempotencyRecordCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := idempotencyrecord.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := idempotencyrecord.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *IdempotencyRecordCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "IdempotencyRecord.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "IdempotencyRecord.updated_at"`)}
+ }
+ if _, ok := _c.mutation.Scope(); !ok {
+ return &ValidationError{Name: "scope", err: errors.New(`ent: missing required field "IdempotencyRecord.scope"`)}
+ }
+ if v, ok := _c.mutation.Scope(); ok {
+ if err := idempotencyrecord.ScopeValidator(v); err != nil {
+ return &ValidationError{Name: "scope", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.scope": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.IdempotencyKeyHash(); !ok {
+ return &ValidationError{Name: "idempotency_key_hash", err: errors.New(`ent: missing required field "IdempotencyRecord.idempotency_key_hash"`)}
+ }
+ if v, ok := _c.mutation.IdempotencyKeyHash(); ok {
+ if err := idempotencyrecord.IdempotencyKeyHashValidator(v); err != nil {
+ return &ValidationError{Name: "idempotency_key_hash", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.idempotency_key_hash": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.RequestFingerprint(); !ok {
+ return &ValidationError{Name: "request_fingerprint", err: errors.New(`ent: missing required field "IdempotencyRecord.request_fingerprint"`)}
+ }
+ if v, ok := _c.mutation.RequestFingerprint(); ok {
+ if err := idempotencyrecord.RequestFingerprintValidator(v); err != nil {
+ return &ValidationError{Name: "request_fingerprint", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.request_fingerprint": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Status(); !ok {
+ return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "IdempotencyRecord.status"`)}
+ }
+ if v, ok := _c.mutation.Status(); ok {
+ if err := idempotencyrecord.StatusValidator(v); err != nil {
+ return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.status": %w`, err)}
+ }
+ }
+ if v, ok := _c.mutation.ErrorReason(); ok {
+ if err := idempotencyrecord.ErrorReasonValidator(v); err != nil {
+ return &ValidationError{Name: "error_reason", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.error_reason": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ExpiresAt(); !ok {
+ return &ValidationError{Name: "expires_at", err: errors.New(`ent: missing required field "IdempotencyRecord.expires_at"`)}
+ }
+ return nil
+}
+
+func (_c *IdempotencyRecordCreate) sqlSave(ctx context.Context) (*IdempotencyRecord, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *IdempotencyRecordCreate) createSpec() (*IdempotencyRecord, *sqlgraph.CreateSpec) {
+ var (
+ _node = &IdempotencyRecord{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(idempotencyrecord.Table, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(idempotencyrecord.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(idempotencyrecord.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.Scope(); ok {
+ _spec.SetField(idempotencyrecord.FieldScope, field.TypeString, value)
+ _node.Scope = value
+ }
+ if value, ok := _c.mutation.IdempotencyKeyHash(); ok {
+ _spec.SetField(idempotencyrecord.FieldIdempotencyKeyHash, field.TypeString, value)
+ _node.IdempotencyKeyHash = value
+ }
+ if value, ok := _c.mutation.RequestFingerprint(); ok {
+ _spec.SetField(idempotencyrecord.FieldRequestFingerprint, field.TypeString, value)
+ _node.RequestFingerprint = value
+ }
+ if value, ok := _c.mutation.Status(); ok {
+ _spec.SetField(idempotencyrecord.FieldStatus, field.TypeString, value)
+ _node.Status = value
+ }
+ if value, ok := _c.mutation.ResponseStatus(); ok {
+ _spec.SetField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value)
+ _node.ResponseStatus = &value
+ }
+ if value, ok := _c.mutation.ResponseBody(); ok {
+ _spec.SetField(idempotencyrecord.FieldResponseBody, field.TypeString, value)
+ _node.ResponseBody = &value
+ }
+ if value, ok := _c.mutation.ErrorReason(); ok {
+ _spec.SetField(idempotencyrecord.FieldErrorReason, field.TypeString, value)
+ _node.ErrorReason = &value
+ }
+ if value, ok := _c.mutation.LockedUntil(); ok {
+ _spec.SetField(idempotencyrecord.FieldLockedUntil, field.TypeTime, value)
+ _node.LockedUntil = &value
+ }
+ if value, ok := _c.mutation.ExpiresAt(); ok {
+ _spec.SetField(idempotencyrecord.FieldExpiresAt, field.TypeTime, value)
+ _node.ExpiresAt = value
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.IdempotencyRecord.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.IdempotencyRecordUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *IdempotencyRecordCreate) OnConflict(opts ...sql.ConflictOption) *IdempotencyRecordUpsertOne {
+ _c.conflict = opts
+ return &IdempotencyRecordUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.IdempotencyRecord.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *IdempotencyRecordCreate) OnConflictColumns(columns ...string) *IdempotencyRecordUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &IdempotencyRecordUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // IdempotencyRecordUpsertOne is the builder for "upsert"-ing
+ // one IdempotencyRecord node.
+ IdempotencyRecordUpsertOne struct {
+ create *IdempotencyRecordCreate
+ }
+
+ // IdempotencyRecordUpsert is the "OnConflict" setter.
+ IdempotencyRecordUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *IdempotencyRecordUpsert) SetUpdatedAt(v time.Time) *IdempotencyRecordUpsert {
+ u.Set(idempotencyrecord.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsert) UpdateUpdatedAt() *IdempotencyRecordUpsert {
+ u.SetExcluded(idempotencyrecord.FieldUpdatedAt)
+ return u
+}
+
+// SetScope sets the "scope" field.
+func (u *IdempotencyRecordUpsert) SetScope(v string) *IdempotencyRecordUpsert {
+ u.Set(idempotencyrecord.FieldScope, v)
+ return u
+}
+
+// UpdateScope sets the "scope" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsert) UpdateScope() *IdempotencyRecordUpsert {
+ u.SetExcluded(idempotencyrecord.FieldScope)
+ return u
+}
+
+// SetIdempotencyKeyHash sets the "idempotency_key_hash" field.
+func (u *IdempotencyRecordUpsert) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpsert {
+ u.Set(idempotencyrecord.FieldIdempotencyKeyHash, v)
+ return u
+}
+
+// UpdateIdempotencyKeyHash sets the "idempotency_key_hash" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsert) UpdateIdempotencyKeyHash() *IdempotencyRecordUpsert {
+ u.SetExcluded(idempotencyrecord.FieldIdempotencyKeyHash)
+ return u
+}
+
+// SetRequestFingerprint sets the "request_fingerprint" field.
+func (u *IdempotencyRecordUpsert) SetRequestFingerprint(v string) *IdempotencyRecordUpsert {
+ u.Set(idempotencyrecord.FieldRequestFingerprint, v)
+ return u
+}
+
+// UpdateRequestFingerprint sets the "request_fingerprint" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsert) UpdateRequestFingerprint() *IdempotencyRecordUpsert {
+ u.SetExcluded(idempotencyrecord.FieldRequestFingerprint)
+ return u
+}
+
+// SetStatus sets the "status" field.
+func (u *IdempotencyRecordUpsert) SetStatus(v string) *IdempotencyRecordUpsert {
+ u.Set(idempotencyrecord.FieldStatus, v)
+ return u
+}
+
+// UpdateStatus sets the "status" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsert) UpdateStatus() *IdempotencyRecordUpsert {
+ u.SetExcluded(idempotencyrecord.FieldStatus)
+ return u
+}
+
+// SetResponseStatus sets the "response_status" field.
+func (u *IdempotencyRecordUpsert) SetResponseStatus(v int) *IdempotencyRecordUpsert {
+ u.Set(idempotencyrecord.FieldResponseStatus, v)
+ return u
+}
+
+// UpdateResponseStatus sets the "response_status" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsert) UpdateResponseStatus() *IdempotencyRecordUpsert {
+ u.SetExcluded(idempotencyrecord.FieldResponseStatus)
+ return u
+}
+
+// AddResponseStatus adds v to the "response_status" field.
+func (u *IdempotencyRecordUpsert) AddResponseStatus(v int) *IdempotencyRecordUpsert {
+ u.Add(idempotencyrecord.FieldResponseStatus, v)
+ return u
+}
+
+// ClearResponseStatus clears the value of the "response_status" field.
+func (u *IdempotencyRecordUpsert) ClearResponseStatus() *IdempotencyRecordUpsert {
+ u.SetNull(idempotencyrecord.FieldResponseStatus)
+ return u
+}
+
+// SetResponseBody sets the "response_body" field.
+func (u *IdempotencyRecordUpsert) SetResponseBody(v string) *IdempotencyRecordUpsert {
+ u.Set(idempotencyrecord.FieldResponseBody, v)
+ return u
+}
+
+// UpdateResponseBody sets the "response_body" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsert) UpdateResponseBody() *IdempotencyRecordUpsert {
+ u.SetExcluded(idempotencyrecord.FieldResponseBody)
+ return u
+}
+
+// ClearResponseBody clears the value of the "response_body" field.
+func (u *IdempotencyRecordUpsert) ClearResponseBody() *IdempotencyRecordUpsert {
+ u.SetNull(idempotencyrecord.FieldResponseBody)
+ return u
+}
+
+// SetErrorReason sets the "error_reason" field.
+func (u *IdempotencyRecordUpsert) SetErrorReason(v string) *IdempotencyRecordUpsert {
+ u.Set(idempotencyrecord.FieldErrorReason, v)
+ return u
+}
+
+// UpdateErrorReason sets the "error_reason" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsert) UpdateErrorReason() *IdempotencyRecordUpsert {
+ u.SetExcluded(idempotencyrecord.FieldErrorReason)
+ return u
+}
+
+// ClearErrorReason clears the value of the "error_reason" field.
+func (u *IdempotencyRecordUpsert) ClearErrorReason() *IdempotencyRecordUpsert {
+ u.SetNull(idempotencyrecord.FieldErrorReason)
+ return u
+}
+
+// SetLockedUntil sets the "locked_until" field.
+func (u *IdempotencyRecordUpsert) SetLockedUntil(v time.Time) *IdempotencyRecordUpsert {
+ u.Set(idempotencyrecord.FieldLockedUntil, v)
+ return u
+}
+
+// UpdateLockedUntil sets the "locked_until" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsert) UpdateLockedUntil() *IdempotencyRecordUpsert {
+ u.SetExcluded(idempotencyrecord.FieldLockedUntil)
+ return u
+}
+
+// ClearLockedUntil clears the value of the "locked_until" field.
+func (u *IdempotencyRecordUpsert) ClearLockedUntil() *IdempotencyRecordUpsert {
+ u.SetNull(idempotencyrecord.FieldLockedUntil)
+ return u
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (u *IdempotencyRecordUpsert) SetExpiresAt(v time.Time) *IdempotencyRecordUpsert {
+ u.Set(idempotencyrecord.FieldExpiresAt, v)
+ return u
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsert) UpdateExpiresAt() *IdempotencyRecordUpsert {
+ u.SetExcluded(idempotencyrecord.FieldExpiresAt)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.IdempotencyRecord.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *IdempotencyRecordUpsertOne) UpdateNewValues() *IdempotencyRecordUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(idempotencyrecord.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.IdempotencyRecord.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *IdempotencyRecordUpsertOne) Ignore() *IdempotencyRecordUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *IdempotencyRecordUpsertOne) DoNothing() *IdempotencyRecordUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the IdempotencyRecordCreate.OnConflict
+// documentation for more info.
+func (u *IdempotencyRecordUpsertOne) Update(set func(*IdempotencyRecordUpsert)) *IdempotencyRecordUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&IdempotencyRecordUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *IdempotencyRecordUpsertOne) SetUpdatedAt(v time.Time) *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsertOne) UpdateUpdatedAt() *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetScope sets the "scope" field.
+func (u *IdempotencyRecordUpsertOne) SetScope(v string) *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.SetScope(v)
+ })
+}
+
+// UpdateScope sets the "scope" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsertOne) UpdateScope() *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.UpdateScope()
+ })
+}
+
+// SetIdempotencyKeyHash sets the "idempotency_key_hash" field.
+func (u *IdempotencyRecordUpsertOne) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.SetIdempotencyKeyHash(v)
+ })
+}
+
+// UpdateIdempotencyKeyHash sets the "idempotency_key_hash" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsertOne) UpdateIdempotencyKeyHash() *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.UpdateIdempotencyKeyHash()
+ })
+}
+
+// SetRequestFingerprint sets the "request_fingerprint" field.
+func (u *IdempotencyRecordUpsertOne) SetRequestFingerprint(v string) *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.SetRequestFingerprint(v)
+ })
+}
+
+// UpdateRequestFingerprint sets the "request_fingerprint" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsertOne) UpdateRequestFingerprint() *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.UpdateRequestFingerprint()
+ })
+}
+
+// SetStatus sets the "status" field.
+func (u *IdempotencyRecordUpsertOne) SetStatus(v string) *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.SetStatus(v)
+ })
+}
+
+// UpdateStatus sets the "status" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsertOne) UpdateStatus() *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.UpdateStatus()
+ })
+}
+
+// SetResponseStatus sets the "response_status" field.
+func (u *IdempotencyRecordUpsertOne) SetResponseStatus(v int) *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.SetResponseStatus(v)
+ })
+}
+
+// AddResponseStatus adds v to the "response_status" field.
+func (u *IdempotencyRecordUpsertOne) AddResponseStatus(v int) *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.AddResponseStatus(v)
+ })
+}
+
+// UpdateResponseStatus sets the "response_status" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsertOne) UpdateResponseStatus() *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.UpdateResponseStatus()
+ })
+}
+
+// ClearResponseStatus clears the value of the "response_status" field.
+func (u *IdempotencyRecordUpsertOne) ClearResponseStatus() *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.ClearResponseStatus()
+ })
+}
+
+// SetResponseBody sets the "response_body" field.
+func (u *IdempotencyRecordUpsertOne) SetResponseBody(v string) *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.SetResponseBody(v)
+ })
+}
+
+// UpdateResponseBody sets the "response_body" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsertOne) UpdateResponseBody() *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.UpdateResponseBody()
+ })
+}
+
+// ClearResponseBody clears the value of the "response_body" field.
+func (u *IdempotencyRecordUpsertOne) ClearResponseBody() *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.ClearResponseBody()
+ })
+}
+
+// SetErrorReason sets the "error_reason" field.
+func (u *IdempotencyRecordUpsertOne) SetErrorReason(v string) *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.SetErrorReason(v)
+ })
+}
+
+// UpdateErrorReason sets the "error_reason" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsertOne) UpdateErrorReason() *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.UpdateErrorReason()
+ })
+}
+
+// ClearErrorReason clears the value of the "error_reason" field.
+func (u *IdempotencyRecordUpsertOne) ClearErrorReason() *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.ClearErrorReason()
+ })
+}
+
+// SetLockedUntil sets the "locked_until" field.
+func (u *IdempotencyRecordUpsertOne) SetLockedUntil(v time.Time) *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.SetLockedUntil(v)
+ })
+}
+
+// UpdateLockedUntil sets the "locked_until" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsertOne) UpdateLockedUntil() *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.UpdateLockedUntil()
+ })
+}
+
+// ClearLockedUntil clears the value of the "locked_until" field.
+func (u *IdempotencyRecordUpsertOne) ClearLockedUntil() *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.ClearLockedUntil()
+ })
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (u *IdempotencyRecordUpsertOne) SetExpiresAt(v time.Time) *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.SetExpiresAt(v)
+ })
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsertOne) UpdateExpiresAt() *IdempotencyRecordUpsertOne {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.UpdateExpiresAt()
+ })
+}
+
+// Exec executes the query.
+func (u *IdempotencyRecordUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for IdempotencyRecordCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *IdempotencyRecordUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *IdempotencyRecordUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *IdempotencyRecordUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// IdempotencyRecordCreateBulk is the builder for creating many IdempotencyRecord entities in bulk.
+type IdempotencyRecordCreateBulk struct {
+ config
+ err error
+ builders []*IdempotencyRecordCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the IdempotencyRecord entities in the database.
+func (_c *IdempotencyRecordCreateBulk) Save(ctx context.Context) ([]*IdempotencyRecord, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*IdempotencyRecord, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*IdempotencyRecordMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *IdempotencyRecordCreateBulk) SaveX(ctx context.Context) []*IdempotencyRecord {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *IdempotencyRecordCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *IdempotencyRecordCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.IdempotencyRecord.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.IdempotencyRecordUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *IdempotencyRecordCreateBulk) OnConflict(opts ...sql.ConflictOption) *IdempotencyRecordUpsertBulk {
+ _c.conflict = opts
+ return &IdempotencyRecordUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.IdempotencyRecord.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *IdempotencyRecordCreateBulk) OnConflictColumns(columns ...string) *IdempotencyRecordUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &IdempotencyRecordUpsertBulk{
+ create: _c,
+ }
+}
+
+// IdempotencyRecordUpsertBulk is the builder for "upsert"-ing
+// a bulk of IdempotencyRecord nodes.
+type IdempotencyRecordUpsertBulk struct {
+ create *IdempotencyRecordCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.IdempotencyRecord.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *IdempotencyRecordUpsertBulk) UpdateNewValues() *IdempotencyRecordUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(idempotencyrecord.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.IdempotencyRecord.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *IdempotencyRecordUpsertBulk) Ignore() *IdempotencyRecordUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *IdempotencyRecordUpsertBulk) DoNothing() *IdempotencyRecordUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the IdempotencyRecordCreateBulk.OnConflict
+// documentation for more info.
+func (u *IdempotencyRecordUpsertBulk) Update(set func(*IdempotencyRecordUpsert)) *IdempotencyRecordUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&IdempotencyRecordUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *IdempotencyRecordUpsertBulk) SetUpdatedAt(v time.Time) *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsertBulk) UpdateUpdatedAt() *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetScope sets the "scope" field.
+func (u *IdempotencyRecordUpsertBulk) SetScope(v string) *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.SetScope(v)
+ })
+}
+
+// UpdateScope sets the "scope" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsertBulk) UpdateScope() *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.UpdateScope()
+ })
+}
+
+// SetIdempotencyKeyHash sets the "idempotency_key_hash" field.
+func (u *IdempotencyRecordUpsertBulk) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.SetIdempotencyKeyHash(v)
+ })
+}
+
+// UpdateIdempotencyKeyHash sets the "idempotency_key_hash" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsertBulk) UpdateIdempotencyKeyHash() *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.UpdateIdempotencyKeyHash()
+ })
+}
+
+// SetRequestFingerprint sets the "request_fingerprint" field.
+func (u *IdempotencyRecordUpsertBulk) SetRequestFingerprint(v string) *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.SetRequestFingerprint(v)
+ })
+}
+
+// UpdateRequestFingerprint sets the "request_fingerprint" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsertBulk) UpdateRequestFingerprint() *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.UpdateRequestFingerprint()
+ })
+}
+
+// SetStatus sets the "status" field.
+func (u *IdempotencyRecordUpsertBulk) SetStatus(v string) *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.SetStatus(v)
+ })
+}
+
+// UpdateStatus sets the "status" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsertBulk) UpdateStatus() *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.UpdateStatus()
+ })
+}
+
+// SetResponseStatus sets the "response_status" field.
+func (u *IdempotencyRecordUpsertBulk) SetResponseStatus(v int) *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.SetResponseStatus(v)
+ })
+}
+
+// AddResponseStatus adds v to the "response_status" field.
+func (u *IdempotencyRecordUpsertBulk) AddResponseStatus(v int) *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.AddResponseStatus(v)
+ })
+}
+
+// UpdateResponseStatus sets the "response_status" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsertBulk) UpdateResponseStatus() *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.UpdateResponseStatus()
+ })
+}
+
+// ClearResponseStatus clears the value of the "response_status" field.
+func (u *IdempotencyRecordUpsertBulk) ClearResponseStatus() *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.ClearResponseStatus()
+ })
+}
+
+// SetResponseBody sets the "response_body" field.
+func (u *IdempotencyRecordUpsertBulk) SetResponseBody(v string) *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.SetResponseBody(v)
+ })
+}
+
+// UpdateResponseBody sets the "response_body" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsertBulk) UpdateResponseBody() *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.UpdateResponseBody()
+ })
+}
+
+// ClearResponseBody clears the value of the "response_body" field.
+func (u *IdempotencyRecordUpsertBulk) ClearResponseBody() *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.ClearResponseBody()
+ })
+}
+
+// SetErrorReason sets the "error_reason" field.
+func (u *IdempotencyRecordUpsertBulk) SetErrorReason(v string) *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.SetErrorReason(v)
+ })
+}
+
+// UpdateErrorReason sets the "error_reason" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsertBulk) UpdateErrorReason() *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.UpdateErrorReason()
+ })
+}
+
+// ClearErrorReason clears the value of the "error_reason" field.
+func (u *IdempotencyRecordUpsertBulk) ClearErrorReason() *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.ClearErrorReason()
+ })
+}
+
+// SetLockedUntil sets the "locked_until" field.
+func (u *IdempotencyRecordUpsertBulk) SetLockedUntil(v time.Time) *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.SetLockedUntil(v)
+ })
+}
+
+// UpdateLockedUntil sets the "locked_until" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsertBulk) UpdateLockedUntil() *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.UpdateLockedUntil()
+ })
+}
+
+// ClearLockedUntil clears the value of the "locked_until" field.
+func (u *IdempotencyRecordUpsertBulk) ClearLockedUntil() *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.ClearLockedUntil()
+ })
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (u *IdempotencyRecordUpsertBulk) SetExpiresAt(v time.Time) *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.SetExpiresAt(v)
+ })
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *IdempotencyRecordUpsertBulk) UpdateExpiresAt() *IdempotencyRecordUpsertBulk {
+ return u.Update(func(s *IdempotencyRecordUpsert) {
+ s.UpdateExpiresAt()
+ })
+}
+
+// Exec executes the query.
+func (u *IdempotencyRecordUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the IdempotencyRecordCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for IdempotencyRecordCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *IdempotencyRecordUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/idempotencyrecord_delete.go b/backend/ent/idempotencyrecord_delete.go
new file mode 100644
index 00000000..f5c87559
--- /dev/null
+++ b/backend/ent/idempotencyrecord_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// IdempotencyRecordDelete is the builder for deleting a IdempotencyRecord entity.
+type IdempotencyRecordDelete struct {
+ config
+ hooks []Hook
+ mutation *IdempotencyRecordMutation
+}
+
+// Where appends a list predicates to the IdempotencyRecordDelete builder.
+func (_d *IdempotencyRecordDelete) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *IdempotencyRecordDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *IdempotencyRecordDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *IdempotencyRecordDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(idempotencyrecord.Table, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// IdempotencyRecordDeleteOne is the builder for deleting a single IdempotencyRecord entity.
+type IdempotencyRecordDeleteOne struct {
+ _d *IdempotencyRecordDelete
+}
+
+// Where appends a list predicates to the IdempotencyRecordDelete builder.
+func (_d *IdempotencyRecordDeleteOne) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *IdempotencyRecordDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{idempotencyrecord.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *IdempotencyRecordDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/idempotencyrecord_query.go b/backend/ent/idempotencyrecord_query.go
new file mode 100644
index 00000000..fbba4dfa
--- /dev/null
+++ b/backend/ent/idempotencyrecord_query.go
@@ -0,0 +1,564 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// IdempotencyRecordQuery is the builder for querying IdempotencyRecord entities.
+type IdempotencyRecordQuery struct {
+ config
+ ctx *QueryContext
+ order []idempotencyrecord.OrderOption
+ inters []Interceptor
+ predicates []predicate.IdempotencyRecord
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the IdempotencyRecordQuery builder.
+func (_q *IdempotencyRecordQuery) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *IdempotencyRecordQuery) Limit(limit int) *IdempotencyRecordQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *IdempotencyRecordQuery) Offset(offset int) *IdempotencyRecordQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *IdempotencyRecordQuery) Unique(unique bool) *IdempotencyRecordQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *IdempotencyRecordQuery) Order(o ...idempotencyrecord.OrderOption) *IdempotencyRecordQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// First returns the first IdempotencyRecord entity from the query.
+// Returns a *NotFoundError when no IdempotencyRecord was found.
+func (_q *IdempotencyRecordQuery) First(ctx context.Context) (*IdempotencyRecord, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{idempotencyrecord.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *IdempotencyRecordQuery) FirstX(ctx context.Context) *IdempotencyRecord {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first IdempotencyRecord ID from the query.
+// Returns a *NotFoundError when no IdempotencyRecord ID was found.
+func (_q *IdempotencyRecordQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{idempotencyrecord.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *IdempotencyRecordQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single IdempotencyRecord entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one IdempotencyRecord entity is found.
+// Returns a *NotFoundError when no IdempotencyRecord entities are found.
+func (_q *IdempotencyRecordQuery) Only(ctx context.Context) (*IdempotencyRecord, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{idempotencyrecord.Label}
+ default:
+ return nil, &NotSingularError{idempotencyrecord.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *IdempotencyRecordQuery) OnlyX(ctx context.Context) *IdempotencyRecord {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only IdempotencyRecord ID in the query.
+// Returns a *NotSingularError when more than one IdempotencyRecord ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *IdempotencyRecordQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{idempotencyrecord.Label}
+ default:
+ err = &NotSingularError{idempotencyrecord.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *IdempotencyRecordQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of IdempotencyRecords.
+func (_q *IdempotencyRecordQuery) All(ctx context.Context) ([]*IdempotencyRecord, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*IdempotencyRecord, *IdempotencyRecordQuery]()
+ return withInterceptors[[]*IdempotencyRecord](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *IdempotencyRecordQuery) AllX(ctx context.Context) []*IdempotencyRecord {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of IdempotencyRecord IDs.
+func (_q *IdempotencyRecordQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(idempotencyrecord.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *IdempotencyRecordQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *IdempotencyRecordQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*IdempotencyRecordQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *IdempotencyRecordQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *IdempotencyRecordQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *IdempotencyRecordQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the IdempotencyRecordQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *IdempotencyRecordQuery) Clone() *IdempotencyRecordQuery {
+ if _q == nil {
+ return nil
+ }
+ return &IdempotencyRecordQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]idempotencyrecord.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.IdempotencyRecord{}, _q.predicates...),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.IdempotencyRecord.Query().
+// GroupBy(idempotencyrecord.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *IdempotencyRecordQuery) GroupBy(field string, fields ...string) *IdempotencyRecordGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &IdempotencyRecordGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = idempotencyrecord.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.IdempotencyRecord.Query().
+// Select(idempotencyrecord.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *IdempotencyRecordQuery) Select(fields ...string) *IdempotencyRecordSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &IdempotencyRecordSelect{IdempotencyRecordQuery: _q}
+ sbuild.label = idempotencyrecord.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a IdempotencyRecordSelect configured with the given aggregations.
+func (_q *IdempotencyRecordQuery) Aggregate(fns ...AggregateFunc) *IdempotencyRecordSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *IdempotencyRecordQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !idempotencyrecord.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *IdempotencyRecordQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*IdempotencyRecord, error) {
+ var (
+ nodes = []*IdempotencyRecord{}
+ _spec = _q.querySpec()
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*IdempotencyRecord).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &IdempotencyRecord{config: _q.config}
+ nodes = append(nodes, node)
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ return nodes, nil
+}
+
+func (_q *IdempotencyRecordQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *IdempotencyRecordQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(idempotencyrecord.Table, idempotencyrecord.Columns, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, idempotencyrecord.FieldID)
+ for i := range fields {
+ if fields[i] != idempotencyrecord.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *IdempotencyRecordQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(idempotencyrecord.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = idempotencyrecord.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *IdempotencyRecordQuery) ForUpdate(opts ...sql.LockOption) *IdempotencyRecordQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *IdempotencyRecordQuery) ForShare(opts ...sql.LockOption) *IdempotencyRecordQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// IdempotencyRecordGroupBy is the group-by builder for IdempotencyRecord entities.
+type IdempotencyRecordGroupBy struct {
+ selector
+ build *IdempotencyRecordQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *IdempotencyRecordGroupBy) Aggregate(fns ...AggregateFunc) *IdempotencyRecordGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *IdempotencyRecordGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*IdempotencyRecordQuery, *IdempotencyRecordGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *IdempotencyRecordGroupBy) sqlScan(ctx context.Context, root *IdempotencyRecordQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// IdempotencyRecordSelect is the builder for selecting fields of IdempotencyRecord entities.
+type IdempotencyRecordSelect struct {
+ *IdempotencyRecordQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *IdempotencyRecordSelect) Aggregate(fns ...AggregateFunc) *IdempotencyRecordSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *IdempotencyRecordSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*IdempotencyRecordQuery, *IdempotencyRecordSelect](ctx, _s.IdempotencyRecordQuery, _s, _s.inters, v)
+}
+
+func (_s *IdempotencyRecordSelect) sqlScan(ctx context.Context, root *IdempotencyRecordQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/idempotencyrecord_update.go b/backend/ent/idempotencyrecord_update.go
new file mode 100644
index 00000000..f839e5c0
--- /dev/null
+++ b/backend/ent/idempotencyrecord_update.go
@@ -0,0 +1,676 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// IdempotencyRecordUpdate is the builder for updating IdempotencyRecord entities.
+type IdempotencyRecordUpdate struct {
+ config
+ hooks []Hook
+ mutation *IdempotencyRecordMutation
+}
+
+// Where appends a list predicates to the IdempotencyRecordUpdate builder.
+func (_u *IdempotencyRecordUpdate) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *IdempotencyRecordUpdate) SetUpdatedAt(v time.Time) *IdempotencyRecordUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetScope sets the "scope" field.
+func (_u *IdempotencyRecordUpdate) SetScope(v string) *IdempotencyRecordUpdate {
+ _u.mutation.SetScope(v)
+ return _u
+}
+
+// SetNillableScope sets the "scope" field if the given value is not nil.
+func (_u *IdempotencyRecordUpdate) SetNillableScope(v *string) *IdempotencyRecordUpdate {
+ if v != nil {
+ _u.SetScope(*v)
+ }
+ return _u
+}
+
+// SetIdempotencyKeyHash sets the "idempotency_key_hash" field.
+func (_u *IdempotencyRecordUpdate) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpdate {
+ _u.mutation.SetIdempotencyKeyHash(v)
+ return _u
+}
+
+// SetNillableIdempotencyKeyHash sets the "idempotency_key_hash" field if the given value is not nil.
+func (_u *IdempotencyRecordUpdate) SetNillableIdempotencyKeyHash(v *string) *IdempotencyRecordUpdate {
+ if v != nil {
+ _u.SetIdempotencyKeyHash(*v)
+ }
+ return _u
+}
+
+// SetRequestFingerprint sets the "request_fingerprint" field.
+func (_u *IdempotencyRecordUpdate) SetRequestFingerprint(v string) *IdempotencyRecordUpdate {
+ _u.mutation.SetRequestFingerprint(v)
+ return _u
+}
+
+// SetNillableRequestFingerprint sets the "request_fingerprint" field if the given value is not nil.
+func (_u *IdempotencyRecordUpdate) SetNillableRequestFingerprint(v *string) *IdempotencyRecordUpdate {
+ if v != nil {
+ _u.SetRequestFingerprint(*v)
+ }
+ return _u
+}
+
+// SetStatus sets the "status" field.
+func (_u *IdempotencyRecordUpdate) SetStatus(v string) *IdempotencyRecordUpdate {
+ _u.mutation.SetStatus(v)
+ return _u
+}
+
+// SetNillableStatus sets the "status" field if the given value is not nil.
+func (_u *IdempotencyRecordUpdate) SetNillableStatus(v *string) *IdempotencyRecordUpdate {
+ if v != nil {
+ _u.SetStatus(*v)
+ }
+ return _u
+}
+
+// SetResponseStatus sets the "response_status" field.
+func (_u *IdempotencyRecordUpdate) SetResponseStatus(v int) *IdempotencyRecordUpdate {
+ _u.mutation.ResetResponseStatus()
+ _u.mutation.SetResponseStatus(v)
+ return _u
+}
+
+// SetNillableResponseStatus sets the "response_status" field if the given value is not nil.
+func (_u *IdempotencyRecordUpdate) SetNillableResponseStatus(v *int) *IdempotencyRecordUpdate {
+ if v != nil {
+ _u.SetResponseStatus(*v)
+ }
+ return _u
+}
+
+// AddResponseStatus adds value to the "response_status" field.
+func (_u *IdempotencyRecordUpdate) AddResponseStatus(v int) *IdempotencyRecordUpdate {
+ _u.mutation.AddResponseStatus(v)
+ return _u
+}
+
+// ClearResponseStatus clears the value of the "response_status" field.
+func (_u *IdempotencyRecordUpdate) ClearResponseStatus() *IdempotencyRecordUpdate {
+ _u.mutation.ClearResponseStatus()
+ return _u
+}
+
+// SetResponseBody sets the "response_body" field.
+func (_u *IdempotencyRecordUpdate) SetResponseBody(v string) *IdempotencyRecordUpdate {
+ _u.mutation.SetResponseBody(v)
+ return _u
+}
+
+// SetNillableResponseBody sets the "response_body" field if the given value is not nil.
+func (_u *IdempotencyRecordUpdate) SetNillableResponseBody(v *string) *IdempotencyRecordUpdate {
+ if v != nil {
+ _u.SetResponseBody(*v)
+ }
+ return _u
+}
+
+// ClearResponseBody clears the value of the "response_body" field.
+func (_u *IdempotencyRecordUpdate) ClearResponseBody() *IdempotencyRecordUpdate {
+ _u.mutation.ClearResponseBody()
+ return _u
+}
+
+// SetErrorReason sets the "error_reason" field.
+func (_u *IdempotencyRecordUpdate) SetErrorReason(v string) *IdempotencyRecordUpdate {
+ _u.mutation.SetErrorReason(v)
+ return _u
+}
+
+// SetNillableErrorReason sets the "error_reason" field if the given value is not nil.
+func (_u *IdempotencyRecordUpdate) SetNillableErrorReason(v *string) *IdempotencyRecordUpdate {
+ if v != nil {
+ _u.SetErrorReason(*v)
+ }
+ return _u
+}
+
+// ClearErrorReason clears the value of the "error_reason" field.
+func (_u *IdempotencyRecordUpdate) ClearErrorReason() *IdempotencyRecordUpdate {
+ _u.mutation.ClearErrorReason()
+ return _u
+}
+
+// SetLockedUntil sets the "locked_until" field.
+func (_u *IdempotencyRecordUpdate) SetLockedUntil(v time.Time) *IdempotencyRecordUpdate {
+ _u.mutation.SetLockedUntil(v)
+ return _u
+}
+
+// SetNillableLockedUntil sets the "locked_until" field if the given value is not nil.
+func (_u *IdempotencyRecordUpdate) SetNillableLockedUntil(v *time.Time) *IdempotencyRecordUpdate {
+ if v != nil {
+ _u.SetLockedUntil(*v)
+ }
+ return _u
+}
+
+// ClearLockedUntil clears the value of the "locked_until" field.
+func (_u *IdempotencyRecordUpdate) ClearLockedUntil() *IdempotencyRecordUpdate {
+ _u.mutation.ClearLockedUntil()
+ return _u
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (_u *IdempotencyRecordUpdate) SetExpiresAt(v time.Time) *IdempotencyRecordUpdate {
+ _u.mutation.SetExpiresAt(v)
+ return _u
+}
+
+// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
+func (_u *IdempotencyRecordUpdate) SetNillableExpiresAt(v *time.Time) *IdempotencyRecordUpdate {
+ if v != nil {
+ _u.SetExpiresAt(*v)
+ }
+ return _u
+}
+
+// Mutation returns the IdempotencyRecordMutation object of the builder.
+func (_u *IdempotencyRecordUpdate) Mutation() *IdempotencyRecordMutation {
+ return _u.mutation
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *IdempotencyRecordUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *IdempotencyRecordUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *IdempotencyRecordUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *IdempotencyRecordUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *IdempotencyRecordUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := idempotencyrecord.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *IdempotencyRecordUpdate) check() error {
+ if v, ok := _u.mutation.Scope(); ok {
+ if err := idempotencyrecord.ScopeValidator(v); err != nil {
+ return &ValidationError{Name: "scope", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.scope": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.IdempotencyKeyHash(); ok {
+ if err := idempotencyrecord.IdempotencyKeyHashValidator(v); err != nil {
+ return &ValidationError{Name: "idempotency_key_hash", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.idempotency_key_hash": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.RequestFingerprint(); ok {
+ if err := idempotencyrecord.RequestFingerprintValidator(v); err != nil {
+ return &ValidationError{Name: "request_fingerprint", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.request_fingerprint": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Status(); ok {
+ if err := idempotencyrecord.StatusValidator(v); err != nil {
+ return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.status": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ErrorReason(); ok {
+ if err := idempotencyrecord.ErrorReasonValidator(v); err != nil {
+ return &ValidationError{Name: "error_reason", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.error_reason": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *IdempotencyRecordUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(idempotencyrecord.Table, idempotencyrecord.Columns, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(idempotencyrecord.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.Scope(); ok {
+ _spec.SetField(idempotencyrecord.FieldScope, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.IdempotencyKeyHash(); ok {
+ _spec.SetField(idempotencyrecord.FieldIdempotencyKeyHash, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RequestFingerprint(); ok {
+ _spec.SetField(idempotencyrecord.FieldRequestFingerprint, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Status(); ok {
+ _spec.SetField(idempotencyrecord.FieldStatus, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ResponseStatus(); ok {
+ _spec.SetField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedResponseStatus(); ok {
+ _spec.AddField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value)
+ }
+ if _u.mutation.ResponseStatusCleared() {
+ _spec.ClearField(idempotencyrecord.FieldResponseStatus, field.TypeInt)
+ }
+ if value, ok := _u.mutation.ResponseBody(); ok {
+ _spec.SetField(idempotencyrecord.FieldResponseBody, field.TypeString, value)
+ }
+ if _u.mutation.ResponseBodyCleared() {
+ _spec.ClearField(idempotencyrecord.FieldResponseBody, field.TypeString)
+ }
+ if value, ok := _u.mutation.ErrorReason(); ok {
+ _spec.SetField(idempotencyrecord.FieldErrorReason, field.TypeString, value)
+ }
+ if _u.mutation.ErrorReasonCleared() {
+ _spec.ClearField(idempotencyrecord.FieldErrorReason, field.TypeString)
+ }
+ if value, ok := _u.mutation.LockedUntil(); ok {
+ _spec.SetField(idempotencyrecord.FieldLockedUntil, field.TypeTime, value)
+ }
+ if _u.mutation.LockedUntilCleared() {
+ _spec.ClearField(idempotencyrecord.FieldLockedUntil, field.TypeTime)
+ }
+ if value, ok := _u.mutation.ExpiresAt(); ok {
+ _spec.SetField(idempotencyrecord.FieldExpiresAt, field.TypeTime, value)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{idempotencyrecord.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// IdempotencyRecordUpdateOne is the builder for updating a single IdempotencyRecord entity.
+type IdempotencyRecordUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *IdempotencyRecordMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *IdempotencyRecordUpdateOne) SetUpdatedAt(v time.Time) *IdempotencyRecordUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetScope sets the "scope" field.
+func (_u *IdempotencyRecordUpdateOne) SetScope(v string) *IdempotencyRecordUpdateOne {
+ _u.mutation.SetScope(v)
+ return _u
+}
+
+// SetNillableScope sets the "scope" field if the given value is not nil.
+func (_u *IdempotencyRecordUpdateOne) SetNillableScope(v *string) *IdempotencyRecordUpdateOne {
+ if v != nil {
+ _u.SetScope(*v)
+ }
+ return _u
+}
+
+// SetIdempotencyKeyHash sets the "idempotency_key_hash" field.
+func (_u *IdempotencyRecordUpdateOne) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpdateOne {
+ _u.mutation.SetIdempotencyKeyHash(v)
+ return _u
+}
+
+// SetNillableIdempotencyKeyHash sets the "idempotency_key_hash" field if the given value is not nil.
+func (_u *IdempotencyRecordUpdateOne) SetNillableIdempotencyKeyHash(v *string) *IdempotencyRecordUpdateOne {
+ if v != nil {
+ _u.SetIdempotencyKeyHash(*v)
+ }
+ return _u
+}
+
+// SetRequestFingerprint sets the "request_fingerprint" field.
+func (_u *IdempotencyRecordUpdateOne) SetRequestFingerprint(v string) *IdempotencyRecordUpdateOne {
+ _u.mutation.SetRequestFingerprint(v)
+ return _u
+}
+
+// SetNillableRequestFingerprint sets the "request_fingerprint" field if the given value is not nil.
+func (_u *IdempotencyRecordUpdateOne) SetNillableRequestFingerprint(v *string) *IdempotencyRecordUpdateOne {
+ if v != nil {
+ _u.SetRequestFingerprint(*v)
+ }
+ return _u
+}
+
+// SetStatus sets the "status" field.
+func (_u *IdempotencyRecordUpdateOne) SetStatus(v string) *IdempotencyRecordUpdateOne {
+ _u.mutation.SetStatus(v)
+ return _u
+}
+
+// SetNillableStatus sets the "status" field if the given value is not nil.
+func (_u *IdempotencyRecordUpdateOne) SetNillableStatus(v *string) *IdempotencyRecordUpdateOne {
+ if v != nil {
+ _u.SetStatus(*v)
+ }
+ return _u
+}
+
+// SetResponseStatus sets the "response_status" field.
+func (_u *IdempotencyRecordUpdateOne) SetResponseStatus(v int) *IdempotencyRecordUpdateOne {
+ _u.mutation.ResetResponseStatus()
+ _u.mutation.SetResponseStatus(v)
+ return _u
+}
+
+// SetNillableResponseStatus sets the "response_status" field if the given value is not nil.
+func (_u *IdempotencyRecordUpdateOne) SetNillableResponseStatus(v *int) *IdempotencyRecordUpdateOne {
+ if v != nil {
+ _u.SetResponseStatus(*v)
+ }
+ return _u
+}
+
+// AddResponseStatus adds value to the "response_status" field.
+func (_u *IdempotencyRecordUpdateOne) AddResponseStatus(v int) *IdempotencyRecordUpdateOne {
+ _u.mutation.AddResponseStatus(v)
+ return _u
+}
+
+// ClearResponseStatus clears the value of the "response_status" field.
+func (_u *IdempotencyRecordUpdateOne) ClearResponseStatus() *IdempotencyRecordUpdateOne {
+ _u.mutation.ClearResponseStatus()
+ return _u
+}
+
+// SetResponseBody sets the "response_body" field.
+func (_u *IdempotencyRecordUpdateOne) SetResponseBody(v string) *IdempotencyRecordUpdateOne {
+ _u.mutation.SetResponseBody(v)
+ return _u
+}
+
+// SetNillableResponseBody sets the "response_body" field if the given value is not nil.
+func (_u *IdempotencyRecordUpdateOne) SetNillableResponseBody(v *string) *IdempotencyRecordUpdateOne {
+ if v != nil {
+ _u.SetResponseBody(*v)
+ }
+ return _u
+}
+
+// ClearResponseBody clears the value of the "response_body" field.
+func (_u *IdempotencyRecordUpdateOne) ClearResponseBody() *IdempotencyRecordUpdateOne {
+ _u.mutation.ClearResponseBody()
+ return _u
+}
+
+// SetErrorReason sets the "error_reason" field.
+func (_u *IdempotencyRecordUpdateOne) SetErrorReason(v string) *IdempotencyRecordUpdateOne {
+ _u.mutation.SetErrorReason(v)
+ return _u
+}
+
+// SetNillableErrorReason sets the "error_reason" field if the given value is not nil.
+func (_u *IdempotencyRecordUpdateOne) SetNillableErrorReason(v *string) *IdempotencyRecordUpdateOne {
+ if v != nil {
+ _u.SetErrorReason(*v)
+ }
+ return _u
+}
+
+// ClearErrorReason clears the value of the "error_reason" field.
+func (_u *IdempotencyRecordUpdateOne) ClearErrorReason() *IdempotencyRecordUpdateOne {
+ _u.mutation.ClearErrorReason()
+ return _u
+}
+
+// SetLockedUntil sets the "locked_until" field.
+func (_u *IdempotencyRecordUpdateOne) SetLockedUntil(v time.Time) *IdempotencyRecordUpdateOne {
+ _u.mutation.SetLockedUntil(v)
+ return _u
+}
+
+// SetNillableLockedUntil sets the "locked_until" field if the given value is not nil.
+func (_u *IdempotencyRecordUpdateOne) SetNillableLockedUntil(v *time.Time) *IdempotencyRecordUpdateOne {
+ if v != nil {
+ _u.SetLockedUntil(*v)
+ }
+ return _u
+}
+
+// ClearLockedUntil clears the value of the "locked_until" field.
+func (_u *IdempotencyRecordUpdateOne) ClearLockedUntil() *IdempotencyRecordUpdateOne {
+ _u.mutation.ClearLockedUntil()
+ return _u
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (_u *IdempotencyRecordUpdateOne) SetExpiresAt(v time.Time) *IdempotencyRecordUpdateOne {
+ _u.mutation.SetExpiresAt(v)
+ return _u
+}
+
+// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
+func (_u *IdempotencyRecordUpdateOne) SetNillableExpiresAt(v *time.Time) *IdempotencyRecordUpdateOne {
+ if v != nil {
+ _u.SetExpiresAt(*v)
+ }
+ return _u
+}
+
+// Mutation returns the IdempotencyRecordMutation object of the builder.
+func (_u *IdempotencyRecordUpdateOne) Mutation() *IdempotencyRecordMutation {
+ return _u.mutation
+}
+
+// Where appends a list predicates to the IdempotencyRecordUpdate builder.
+func (_u *IdempotencyRecordUpdateOne) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *IdempotencyRecordUpdateOne) Select(field string, fields ...string) *IdempotencyRecordUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated IdempotencyRecord entity.
+func (_u *IdempotencyRecordUpdateOne) Save(ctx context.Context) (*IdempotencyRecord, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *IdempotencyRecordUpdateOne) SaveX(ctx context.Context) *IdempotencyRecord {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *IdempotencyRecordUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *IdempotencyRecordUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *IdempotencyRecordUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := idempotencyrecord.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *IdempotencyRecordUpdateOne) check() error {
+ if v, ok := _u.mutation.Scope(); ok {
+ if err := idempotencyrecord.ScopeValidator(v); err != nil {
+ return &ValidationError{Name: "scope", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.scope": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.IdempotencyKeyHash(); ok {
+ if err := idempotencyrecord.IdempotencyKeyHashValidator(v); err != nil {
+ return &ValidationError{Name: "idempotency_key_hash", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.idempotency_key_hash": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.RequestFingerprint(); ok {
+ if err := idempotencyrecord.RequestFingerprintValidator(v); err != nil {
+ return &ValidationError{Name: "request_fingerprint", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.request_fingerprint": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Status(); ok {
+ if err := idempotencyrecord.StatusValidator(v); err != nil {
+ return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.status": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ErrorReason(); ok {
+ if err := idempotencyrecord.ErrorReasonValidator(v); err != nil {
+ return &ValidationError{Name: "error_reason", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.error_reason": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *IdempotencyRecordUpdateOne) sqlSave(ctx context.Context) (_node *IdempotencyRecord, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(idempotencyrecord.Table, idempotencyrecord.Columns, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "IdempotencyRecord.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, idempotencyrecord.FieldID)
+ for _, f := range fields {
+ if !idempotencyrecord.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != idempotencyrecord.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(idempotencyrecord.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.Scope(); ok {
+ _spec.SetField(idempotencyrecord.FieldScope, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.IdempotencyKeyHash(); ok {
+ _spec.SetField(idempotencyrecord.FieldIdempotencyKeyHash, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RequestFingerprint(); ok {
+ _spec.SetField(idempotencyrecord.FieldRequestFingerprint, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Status(); ok {
+ _spec.SetField(idempotencyrecord.FieldStatus, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ResponseStatus(); ok {
+ _spec.SetField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedResponseStatus(); ok {
+ _spec.AddField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value)
+ }
+ if _u.mutation.ResponseStatusCleared() {
+ _spec.ClearField(idempotencyrecord.FieldResponseStatus, field.TypeInt)
+ }
+ if value, ok := _u.mutation.ResponseBody(); ok {
+ _spec.SetField(idempotencyrecord.FieldResponseBody, field.TypeString, value)
+ }
+ if _u.mutation.ResponseBodyCleared() {
+ _spec.ClearField(idempotencyrecord.FieldResponseBody, field.TypeString)
+ }
+ if value, ok := _u.mutation.ErrorReason(); ok {
+ _spec.SetField(idempotencyrecord.FieldErrorReason, field.TypeString, value)
+ }
+ if _u.mutation.ErrorReasonCleared() {
+ _spec.ClearField(idempotencyrecord.FieldErrorReason, field.TypeString)
+ }
+ if value, ok := _u.mutation.LockedUntil(); ok {
+ _spec.SetField(idempotencyrecord.FieldLockedUntil, field.TypeTime, value)
+ }
+ if _u.mutation.LockedUntilCleared() {
+ _spec.ClearField(idempotencyrecord.FieldLockedUntil, field.TypeTime)
+ }
+ if value, ok := _u.mutation.ExpiresAt(); ok {
+ _spec.SetField(idempotencyrecord.FieldExpiresAt, field.TypeTime, value)
+ }
+ _node = &IdempotencyRecord{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{idempotencyrecord.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go
index 290fb163..e7746402 100644
--- a/backend/ent/intercept/intercept.go
+++ b/backend/ent/intercept/intercept.go
@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
+ "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
@@ -276,6 +277,33 @@ func (f TraverseGroup) Traverse(ctx context.Context, q ent.Query) error {
return fmt.Errorf("unexpected query type %T. expect *ent.GroupQuery", q)
}
+// The IdempotencyRecordFunc type is an adapter to allow the use of ordinary function as a Querier.
+type IdempotencyRecordFunc func(context.Context, *ent.IdempotencyRecordQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f IdempotencyRecordFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.IdempotencyRecordQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q)
+}
+
+// The TraverseIdempotencyRecord type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseIdempotencyRecord func(context.Context, *ent.IdempotencyRecordQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseIdempotencyRecord) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseIdempotencyRecord) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.IdempotencyRecordQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q)
+}
+
// The PromoCodeFunc type is an adapter to allow the use of ordinary function as a Querier.
type PromoCodeFunc func(context.Context, *ent.PromoCodeQuery) (ent.Value, error)
@@ -644,6 +672,8 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil
case *ent.GroupQuery:
return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil
+ case *ent.IdempotencyRecordQuery:
+ return &query[*ent.IdempotencyRecordQuery, predicate.IdempotencyRecord, idempotencyrecord.OrderOption]{typ: ent.TypeIdempotencyRecord, tq: q}, nil
case *ent.PromoCodeQuery:
return &query[*ent.PromoCodeQuery, predicate.PromoCode, promocode.OrderOption]{typ: ent.TypePromoCode, tq: q}, nil
case *ent.PromoCodeUsageQuery:
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index aba00d4f..769dddce 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -108,6 +108,8 @@ var (
{Name: "rate_limited_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "rate_limit_reset_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "overload_until", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "temp_unschedulable_until", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "temp_unschedulable_reason", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "session_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "session_window_end", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "session_window_status", Type: field.TypeString, Nullable: true, Size: 20},
@@ -121,7 +123,7 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "accounts_proxies_proxy",
- Columns: []*schema.Column{AccountsColumns[25]},
+ Columns: []*schema.Column{AccountsColumns[27]},
RefColumns: []*schema.Column{ProxiesColumns[0]},
OnDelete: schema.SetNull,
},
@@ -145,7 +147,7 @@ var (
{
Name: "account_proxy_id",
Unique: false,
- Columns: []*schema.Column{AccountsColumns[25]},
+ Columns: []*schema.Column{AccountsColumns[27]},
},
{
Name: "account_priority",
@@ -177,6 +179,16 @@ var (
Unique: false,
Columns: []*schema.Column{AccountsColumns[21]},
},
+ {
+ Name: "account_platform_priority",
+ Unique: false,
+ Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[11]},
+ },
+ {
+ Name: "account_priority_status",
+ Unique: false,
+ Columns: []*schema.Column{AccountsColumns[11], AccountsColumns[13]},
+ },
{
Name: "account_deleted_at",
Unique: false,
@@ -376,6 +388,7 @@ var (
{Name: "sora_image_price_540", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "sora_video_price_per_request", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "sora_video_price_per_request_hd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
+ {Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0},
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
{Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true},
@@ -419,7 +432,45 @@ var (
{
Name: "group_sort_order",
Unique: false,
- Columns: []*schema.Column{GroupsColumns[29]},
+ Columns: []*schema.Column{GroupsColumns[30]},
+ },
+ },
+ }
+ // IdempotencyRecordsColumns holds the columns for the "idempotency_records" table.
+ IdempotencyRecordsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "scope", Type: field.TypeString, Size: 128},
+ {Name: "idempotency_key_hash", Type: field.TypeString, Size: 64},
+ {Name: "request_fingerprint", Type: field.TypeString, Size: 64},
+ {Name: "status", Type: field.TypeString, Size: 32},
+ {Name: "response_status", Type: field.TypeInt, Nullable: true},
+ {Name: "response_body", Type: field.TypeString, Nullable: true},
+ {Name: "error_reason", Type: field.TypeString, Nullable: true, Size: 128},
+ {Name: "locked_until", Type: field.TypeTime, Nullable: true},
+ {Name: "expires_at", Type: field.TypeTime},
+ }
+ // IdempotencyRecordsTable holds the schema information for the "idempotency_records" table.
+ IdempotencyRecordsTable = &schema.Table{
+ Name: "idempotency_records",
+ Columns: IdempotencyRecordsColumns,
+ PrimaryKey: []*schema.Column{IdempotencyRecordsColumns[0]},
+ Indexes: []*schema.Index{
+ {
+ Name: "idempotencyrecord_scope_idempotency_key_hash",
+ Unique: true,
+ Columns: []*schema.Column{IdempotencyRecordsColumns[3], IdempotencyRecordsColumns[4]},
+ },
+ {
+ Name: "idempotencyrecord_expires_at",
+ Unique: false,
+ Columns: []*schema.Column{IdempotencyRecordsColumns[11]},
+ },
+ {
+ Name: "idempotencyrecord_status_locked_until",
+ Unique: false,
+ Columns: []*schema.Column{IdempotencyRecordsColumns[6], IdempotencyRecordsColumns[10]},
},
},
}
@@ -771,6 +822,11 @@ var (
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]},
},
+ {
+ Name: "usagelog_group_id_created_at",
+ Unique: false,
+ Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[27]},
+ },
},
}
// UsersColumns holds the columns for the "users" table.
@@ -790,6 +846,8 @@ var (
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
+ {Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0},
+ {Name: "sora_storage_used_bytes", Type: field.TypeInt64, Default: 0},
}
// UsersTable holds the schema information for the "users" table.
UsersTable = &schema.Table{
@@ -995,6 +1053,11 @@ var (
Unique: false,
Columns: []*schema.Column{UserSubscriptionsColumns[5]},
},
+ {
+ Name: "usersubscription_user_id_status_expires_at",
+ Unique: false,
+ Columns: []*schema.Column{UserSubscriptionsColumns[16], UserSubscriptionsColumns[6], UserSubscriptionsColumns[5]},
+ },
{
Name: "usersubscription_assigned_by",
Unique: false,
@@ -1021,6 +1084,7 @@ var (
AnnouncementReadsTable,
ErrorPassthroughRulesTable,
GroupsTable,
+ IdempotencyRecordsTable,
PromoCodesTable,
PromoCodeUsagesTable,
ProxiesTable,
@@ -1066,6 +1130,9 @@ func init() {
GroupsTable.Annotation = &entsql.Annotation{
Table: "groups",
}
+ IdempotencyRecordsTable.Annotation = &entsql.Annotation{
+ Table: "idempotency_records",
+ }
PromoCodesTable.Annotation = &entsql.Annotation{
Table: "promo_codes",
}
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index 7d5bf180..823cd389 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -19,6 +19,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
+ "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
@@ -52,6 +53,7 @@ const (
TypeAnnouncementRead = "AnnouncementRead"
TypeErrorPassthroughRule = "ErrorPassthroughRule"
TypeGroup = "Group"
+ TypeIdempotencyRecord = "IdempotencyRecord"
TypePromoCode = "PromoCode"
TypePromoCodeUsage = "PromoCodeUsage"
TypeProxy = "Proxy"
@@ -1503,48 +1505,50 @@ func (m *APIKeyMutation) ResetEdge(name string) error {
// AccountMutation represents an operation that mutates the Account nodes in the graph.
type AccountMutation struct {
config
- op Op
- typ string
- id *int64
- created_at *time.Time
- updated_at *time.Time
- deleted_at *time.Time
- name *string
- notes *string
- platform *string
- _type *string
- credentials *map[string]interface{}
- extra *map[string]interface{}
- concurrency *int
- addconcurrency *int
- priority *int
- addpriority *int
- rate_multiplier *float64
- addrate_multiplier *float64
- status *string
- error_message *string
- last_used_at *time.Time
- expires_at *time.Time
- auto_pause_on_expired *bool
- schedulable *bool
- rate_limited_at *time.Time
- rate_limit_reset_at *time.Time
- overload_until *time.Time
- session_window_start *time.Time
- session_window_end *time.Time
- session_window_status *string
- clearedFields map[string]struct{}
- groups map[int64]struct{}
- removedgroups map[int64]struct{}
- clearedgroups bool
- proxy *int64
- clearedproxy bool
- usage_logs map[int64]struct{}
- removedusage_logs map[int64]struct{}
- clearedusage_logs bool
- done bool
- oldValue func(context.Context) (*Account, error)
- predicates []predicate.Account
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ deleted_at *time.Time
+ name *string
+ notes *string
+ platform *string
+ _type *string
+ credentials *map[string]interface{}
+ extra *map[string]interface{}
+ concurrency *int
+ addconcurrency *int
+ priority *int
+ addpriority *int
+ rate_multiplier *float64
+ addrate_multiplier *float64
+ status *string
+ error_message *string
+ last_used_at *time.Time
+ expires_at *time.Time
+ auto_pause_on_expired *bool
+ schedulable *bool
+ rate_limited_at *time.Time
+ rate_limit_reset_at *time.Time
+ overload_until *time.Time
+ temp_unschedulable_until *time.Time
+ temp_unschedulable_reason *string
+ session_window_start *time.Time
+ session_window_end *time.Time
+ session_window_status *string
+ clearedFields map[string]struct{}
+ groups map[int64]struct{}
+ removedgroups map[int64]struct{}
+ clearedgroups bool
+ proxy *int64
+ clearedproxy bool
+ usage_logs map[int64]struct{}
+ removedusage_logs map[int64]struct{}
+ clearedusage_logs bool
+ done bool
+ oldValue func(context.Context) (*Account, error)
+ predicates []predicate.Account
}
var _ ent.Mutation = (*AccountMutation)(nil)
@@ -2614,6 +2618,104 @@ func (m *AccountMutation) ResetOverloadUntil() {
delete(m.clearedFields, account.FieldOverloadUntil)
}
+// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field.
+func (m *AccountMutation) SetTempUnschedulableUntil(t time.Time) {
+ m.temp_unschedulable_until = &t
+}
+
+// TempUnschedulableUntil returns the value of the "temp_unschedulable_until" field in the mutation.
+func (m *AccountMutation) TempUnschedulableUntil() (r time.Time, exists bool) {
+ v := m.temp_unschedulable_until
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTempUnschedulableUntil returns the old "temp_unschedulable_until" field's value of the Account entity.
+// If the Account object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AccountMutation) OldTempUnschedulableUntil(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTempUnschedulableUntil is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTempUnschedulableUntil requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTempUnschedulableUntil: %w", err)
+ }
+ return oldValue.TempUnschedulableUntil, nil
+}
+
+// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field.
+func (m *AccountMutation) ClearTempUnschedulableUntil() {
+ m.temp_unschedulable_until = nil
+ m.clearedFields[account.FieldTempUnschedulableUntil] = struct{}{}
+}
+
+// TempUnschedulableUntilCleared returns if the "temp_unschedulable_until" field was cleared in this mutation.
+func (m *AccountMutation) TempUnschedulableUntilCleared() bool {
+ _, ok := m.clearedFields[account.FieldTempUnschedulableUntil]
+ return ok
+}
+
+// ResetTempUnschedulableUntil resets all changes to the "temp_unschedulable_until" field.
+func (m *AccountMutation) ResetTempUnschedulableUntil() {
+ m.temp_unschedulable_until = nil
+ delete(m.clearedFields, account.FieldTempUnschedulableUntil)
+}
+
+// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field.
+func (m *AccountMutation) SetTempUnschedulableReason(s string) {
+ m.temp_unschedulable_reason = &s
+}
+
+// TempUnschedulableReason returns the value of the "temp_unschedulable_reason" field in the mutation.
+func (m *AccountMutation) TempUnschedulableReason() (r string, exists bool) {
+ v := m.temp_unschedulable_reason
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTempUnschedulableReason returns the old "temp_unschedulable_reason" field's value of the Account entity.
+// If the Account object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AccountMutation) OldTempUnschedulableReason(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTempUnschedulableReason is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTempUnschedulableReason requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTempUnschedulableReason: %w", err)
+ }
+ return oldValue.TempUnschedulableReason, nil
+}
+
+// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field.
+func (m *AccountMutation) ClearTempUnschedulableReason() {
+ m.temp_unschedulable_reason = nil
+ m.clearedFields[account.FieldTempUnschedulableReason] = struct{}{}
+}
+
+// TempUnschedulableReasonCleared returns if the "temp_unschedulable_reason" field was cleared in this mutation.
+func (m *AccountMutation) TempUnschedulableReasonCleared() bool {
+ _, ok := m.clearedFields[account.FieldTempUnschedulableReason]
+ return ok
+}
+
+// ResetTempUnschedulableReason resets all changes to the "temp_unschedulable_reason" field.
+func (m *AccountMutation) ResetTempUnschedulableReason() {
+ m.temp_unschedulable_reason = nil
+ delete(m.clearedFields, account.FieldTempUnschedulableReason)
+}
+
// SetSessionWindowStart sets the "session_window_start" field.
func (m *AccountMutation) SetSessionWindowStart(t time.Time) {
m.session_window_start = &t
@@ -2930,7 +3032,7 @@ func (m *AccountMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *AccountMutation) Fields() []string {
- fields := make([]string, 0, 25)
+ fields := make([]string, 0, 27)
if m.created_at != nil {
fields = append(fields, account.FieldCreatedAt)
}
@@ -2997,6 +3099,12 @@ func (m *AccountMutation) Fields() []string {
if m.overload_until != nil {
fields = append(fields, account.FieldOverloadUntil)
}
+ if m.temp_unschedulable_until != nil {
+ fields = append(fields, account.FieldTempUnschedulableUntil)
+ }
+ if m.temp_unschedulable_reason != nil {
+ fields = append(fields, account.FieldTempUnschedulableReason)
+ }
if m.session_window_start != nil {
fields = append(fields, account.FieldSessionWindowStart)
}
@@ -3058,6 +3166,10 @@ func (m *AccountMutation) Field(name string) (ent.Value, bool) {
return m.RateLimitResetAt()
case account.FieldOverloadUntil:
return m.OverloadUntil()
+ case account.FieldTempUnschedulableUntil:
+ return m.TempUnschedulableUntil()
+ case account.FieldTempUnschedulableReason:
+ return m.TempUnschedulableReason()
case account.FieldSessionWindowStart:
return m.SessionWindowStart()
case account.FieldSessionWindowEnd:
@@ -3117,6 +3229,10 @@ func (m *AccountMutation) OldField(ctx context.Context, name string) (ent.Value,
return m.OldRateLimitResetAt(ctx)
case account.FieldOverloadUntil:
return m.OldOverloadUntil(ctx)
+ case account.FieldTempUnschedulableUntil:
+ return m.OldTempUnschedulableUntil(ctx)
+ case account.FieldTempUnschedulableReason:
+ return m.OldTempUnschedulableReason(ctx)
case account.FieldSessionWindowStart:
return m.OldSessionWindowStart(ctx)
case account.FieldSessionWindowEnd:
@@ -3286,6 +3402,20 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error {
}
m.SetOverloadUntil(v)
return nil
+ case account.FieldTempUnschedulableUntil:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTempUnschedulableUntil(v)
+ return nil
+ case account.FieldTempUnschedulableReason:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTempUnschedulableReason(v)
+ return nil
case account.FieldSessionWindowStart:
v, ok := value.(time.Time)
if !ok {
@@ -3403,6 +3533,12 @@ func (m *AccountMutation) ClearedFields() []string {
if m.FieldCleared(account.FieldOverloadUntil) {
fields = append(fields, account.FieldOverloadUntil)
}
+ if m.FieldCleared(account.FieldTempUnschedulableUntil) {
+ fields = append(fields, account.FieldTempUnschedulableUntil)
+ }
+ if m.FieldCleared(account.FieldTempUnschedulableReason) {
+ fields = append(fields, account.FieldTempUnschedulableReason)
+ }
if m.FieldCleared(account.FieldSessionWindowStart) {
fields = append(fields, account.FieldSessionWindowStart)
}
@@ -3453,6 +3589,12 @@ func (m *AccountMutation) ClearField(name string) error {
case account.FieldOverloadUntil:
m.ClearOverloadUntil()
return nil
+ case account.FieldTempUnschedulableUntil:
+ m.ClearTempUnschedulableUntil()
+ return nil
+ case account.FieldTempUnschedulableReason:
+ m.ClearTempUnschedulableReason()
+ return nil
case account.FieldSessionWindowStart:
m.ClearSessionWindowStart()
return nil
@@ -3536,6 +3678,12 @@ func (m *AccountMutation) ResetField(name string) error {
case account.FieldOverloadUntil:
m.ResetOverloadUntil()
return nil
+ case account.FieldTempUnschedulableUntil:
+ m.ResetTempUnschedulableUntil()
+ return nil
+ case account.FieldTempUnschedulableReason:
+ m.ResetTempUnschedulableReason()
+ return nil
case account.FieldSessionWindowStart:
m.ResetSessionWindowStart()
return nil
@@ -7186,6 +7334,8 @@ type GroupMutation struct {
addsora_video_price_per_request *float64
sora_video_price_per_request_hd *float64
addsora_video_price_per_request_hd *float64
+ sora_storage_quota_bytes *int64
+ addsora_storage_quota_bytes *int64
claude_code_only *bool
fallback_group_id *int64
addfallback_group_id *int64
@@ -8482,6 +8632,62 @@ func (m *GroupMutation) ResetSoraVideoPricePerRequestHd() {
delete(m.clearedFields, group.FieldSoraVideoPricePerRequestHd)
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (m *GroupMutation) SetSoraStorageQuotaBytes(i int64) {
+ m.sora_storage_quota_bytes = &i
+ m.addsora_storage_quota_bytes = nil
+}
+
+// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation.
+func (m *GroupMutation) SoraStorageQuotaBytes() (r int64, exists bool) {
+ v := m.sora_storage_quota_bytes
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err)
+ }
+ return oldValue.SoraStorageQuotaBytes, nil
+}
+
+// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field.
+func (m *GroupMutation) AddSoraStorageQuotaBytes(i int64) {
+ if m.addsora_storage_quota_bytes != nil {
+ *m.addsora_storage_quota_bytes += i
+ } else {
+ m.addsora_storage_quota_bytes = &i
+ }
+}
+
+// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation.
+func (m *GroupMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) {
+ v := m.addsora_storage_quota_bytes
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field.
+func (m *GroupMutation) ResetSoraStorageQuotaBytes() {
+ m.sora_storage_quota_bytes = nil
+ m.addsora_storage_quota_bytes = nil
+}
+
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (m *GroupMutation) SetClaudeCodeOnly(b bool) {
m.claude_code_only = &b
@@ -9244,7 +9450,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *GroupMutation) Fields() []string {
- fields := make([]string, 0, 29)
+ fields := make([]string, 0, 30)
if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt)
}
@@ -9308,6 +9514,9 @@ func (m *GroupMutation) Fields() []string {
if m.sora_video_price_per_request_hd != nil {
fields = append(fields, group.FieldSoraVideoPricePerRequestHd)
}
+ if m.sora_storage_quota_bytes != nil {
+ fields = append(fields, group.FieldSoraStorageQuotaBytes)
+ }
if m.claude_code_only != nil {
fields = append(fields, group.FieldClaudeCodeOnly)
}
@@ -9382,6 +9591,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.SoraVideoPricePerRequest()
case group.FieldSoraVideoPricePerRequestHd:
return m.SoraVideoPricePerRequestHd()
+ case group.FieldSoraStorageQuotaBytes:
+ return m.SoraStorageQuotaBytes()
case group.FieldClaudeCodeOnly:
return m.ClaudeCodeOnly()
case group.FieldFallbackGroupID:
@@ -9449,6 +9660,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldSoraVideoPricePerRequest(ctx)
case group.FieldSoraVideoPricePerRequestHd:
return m.OldSoraVideoPricePerRequestHd(ctx)
+ case group.FieldSoraStorageQuotaBytes:
+ return m.OldSoraStorageQuotaBytes(ctx)
case group.FieldClaudeCodeOnly:
return m.OldClaudeCodeOnly(ctx)
case group.FieldFallbackGroupID:
@@ -9621,6 +9834,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
}
m.SetSoraVideoPricePerRequestHd(v)
return nil
+ case group.FieldSoraStorageQuotaBytes:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSoraStorageQuotaBytes(v)
+ return nil
case group.FieldClaudeCodeOnly:
v, ok := value.(bool)
if !ok {
@@ -9721,6 +9941,9 @@ func (m *GroupMutation) AddedFields() []string {
if m.addsora_video_price_per_request_hd != nil {
fields = append(fields, group.FieldSoraVideoPricePerRequestHd)
}
+ if m.addsora_storage_quota_bytes != nil {
+ fields = append(fields, group.FieldSoraStorageQuotaBytes)
+ }
if m.addfallback_group_id != nil {
fields = append(fields, group.FieldFallbackGroupID)
}
@@ -9762,6 +9985,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedSoraVideoPricePerRequest()
case group.FieldSoraVideoPricePerRequestHd:
return m.AddedSoraVideoPricePerRequestHd()
+ case group.FieldSoraStorageQuotaBytes:
+ return m.AddedSoraStorageQuotaBytes()
case group.FieldFallbackGroupID:
return m.AddedFallbackGroupID()
case group.FieldFallbackGroupIDOnInvalidRequest:
@@ -9861,6 +10086,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
}
m.AddSoraVideoPricePerRequestHd(v)
return nil
+ case group.FieldSoraStorageQuotaBytes:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddSoraStorageQuotaBytes(v)
+ return nil
case group.FieldFallbackGroupID:
v, ok := value.(int64)
if !ok {
@@ -10065,6 +10297,9 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldSoraVideoPricePerRequestHd:
m.ResetSoraVideoPricePerRequestHd()
return nil
+ case group.FieldSoraStorageQuotaBytes:
+ m.ResetSoraStorageQuotaBytes()
+ return nil
case group.FieldClaudeCodeOnly:
m.ResetClaudeCodeOnly()
return nil
@@ -10307,6 +10542,988 @@ func (m *GroupMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown Group edge %s", name)
}
+// IdempotencyRecordMutation represents an operation that mutates the IdempotencyRecord nodes in the graph.
+type IdempotencyRecordMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ scope *string
+ idempotency_key_hash *string
+ request_fingerprint *string
+ status *string
+ response_status *int
+ addresponse_status *int
+ response_body *string
+ error_reason *string
+ locked_until *time.Time
+ expires_at *time.Time
+ clearedFields map[string]struct{}
+ done bool
+ oldValue func(context.Context) (*IdempotencyRecord, error)
+ predicates []predicate.IdempotencyRecord
+}
+
+var _ ent.Mutation = (*IdempotencyRecordMutation)(nil)
+
+// idempotencyrecordOption allows management of the mutation configuration using functional options.
+type idempotencyrecordOption func(*IdempotencyRecordMutation)
+
+// newIdempotencyRecordMutation creates new mutation for the IdempotencyRecord entity.
+func newIdempotencyRecordMutation(c config, op Op, opts ...idempotencyrecordOption) *IdempotencyRecordMutation {
+ m := &IdempotencyRecordMutation{
+ config: c,
+ op: op,
+ typ: TypeIdempotencyRecord,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withIdempotencyRecordID sets the ID field of the mutation.
+func withIdempotencyRecordID(id int64) idempotencyrecordOption {
+ return func(m *IdempotencyRecordMutation) {
+ var (
+ err error
+ once sync.Once
+ value *IdempotencyRecord
+ )
+ m.oldValue = func(ctx context.Context) (*IdempotencyRecord, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().IdempotencyRecord.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withIdempotencyRecord sets the old IdempotencyRecord of the mutation.
+func withIdempotencyRecord(node *IdempotencyRecord) idempotencyrecordOption {
+ return func(m *IdempotencyRecordMutation) {
+ m.oldValue = func(context.Context) (*IdempotencyRecord, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m IdempotencyRecordMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m IdempotencyRecordMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *IdempotencyRecordMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *IdempotencyRecordMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().IdempotencyRecord.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *IdempotencyRecordMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *IdempotencyRecordMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdempotencyRecordMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *IdempotencyRecordMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *IdempotencyRecordMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *IdempotencyRecordMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdempotencyRecordMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *IdempotencyRecordMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetScope sets the "scope" field.
+func (m *IdempotencyRecordMutation) SetScope(s string) {
+ m.scope = &s
+}
+
+// Scope returns the value of the "scope" field in the mutation.
+func (m *IdempotencyRecordMutation) Scope() (r string, exists bool) {
+ v := m.scope
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldScope returns the old "scope" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdempotencyRecordMutation) OldScope(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldScope is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldScope requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldScope: %w", err)
+ }
+ return oldValue.Scope, nil
+}
+
+// ResetScope resets all changes to the "scope" field.
+func (m *IdempotencyRecordMutation) ResetScope() {
+ m.scope = nil
+}
+
+// SetIdempotencyKeyHash sets the "idempotency_key_hash" field.
+func (m *IdempotencyRecordMutation) SetIdempotencyKeyHash(s string) {
+ m.idempotency_key_hash = &s
+}
+
+// IdempotencyKeyHash returns the value of the "idempotency_key_hash" field in the mutation.
+func (m *IdempotencyRecordMutation) IdempotencyKeyHash() (r string, exists bool) {
+ v := m.idempotency_key_hash
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldIdempotencyKeyHash returns the old "idempotency_key_hash" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdempotencyRecordMutation) OldIdempotencyKeyHash(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIdempotencyKeyHash is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIdempotencyKeyHash requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIdempotencyKeyHash: %w", err)
+ }
+ return oldValue.IdempotencyKeyHash, nil
+}
+
+// ResetIdempotencyKeyHash resets all changes to the "idempotency_key_hash" field.
+func (m *IdempotencyRecordMutation) ResetIdempotencyKeyHash() {
+ m.idempotency_key_hash = nil
+}
+
+// SetRequestFingerprint sets the "request_fingerprint" field.
+func (m *IdempotencyRecordMutation) SetRequestFingerprint(s string) {
+ m.request_fingerprint = &s
+}
+
+// RequestFingerprint returns the value of the "request_fingerprint" field in the mutation.
+func (m *IdempotencyRecordMutation) RequestFingerprint() (r string, exists bool) {
+ v := m.request_fingerprint
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRequestFingerprint returns the old "request_fingerprint" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdempotencyRecordMutation) OldRequestFingerprint(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRequestFingerprint is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRequestFingerprint requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRequestFingerprint: %w", err)
+ }
+ return oldValue.RequestFingerprint, nil
+}
+
+// ResetRequestFingerprint resets all changes to the "request_fingerprint" field.
+func (m *IdempotencyRecordMutation) ResetRequestFingerprint() {
+ m.request_fingerprint = nil
+}
+
+// SetStatus sets the "status" field.
+func (m *IdempotencyRecordMutation) SetStatus(s string) {
+ m.status = &s
+}
+
+// Status returns the value of the "status" field in the mutation.
+func (m *IdempotencyRecordMutation) Status() (r string, exists bool) {
+ v := m.status
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldStatus returns the old "status" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdempotencyRecordMutation) OldStatus(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldStatus is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldStatus requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldStatus: %w", err)
+ }
+ return oldValue.Status, nil
+}
+
+// ResetStatus resets all changes to the "status" field.
+func (m *IdempotencyRecordMutation) ResetStatus() {
+ m.status = nil
+}
+
+// SetResponseStatus sets the "response_status" field.
+func (m *IdempotencyRecordMutation) SetResponseStatus(i int) {
+ m.response_status = &i
+ m.addresponse_status = nil
+}
+
+// ResponseStatus returns the value of the "response_status" field in the mutation.
+func (m *IdempotencyRecordMutation) ResponseStatus() (r int, exists bool) {
+ v := m.response_status
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldResponseStatus returns the old "response_status" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdempotencyRecordMutation) OldResponseStatus(ctx context.Context) (v *int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldResponseStatus is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldResponseStatus requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldResponseStatus: %w", err)
+ }
+ return oldValue.ResponseStatus, nil
+}
+
+// AddResponseStatus adds i to the "response_status" field.
+func (m *IdempotencyRecordMutation) AddResponseStatus(i int) {
+ if m.addresponse_status != nil {
+ *m.addresponse_status += i
+ } else {
+ m.addresponse_status = &i
+ }
+}
+
+// AddedResponseStatus returns the value that was added to the "response_status" field in this mutation.
+func (m *IdempotencyRecordMutation) AddedResponseStatus() (r int, exists bool) {
+ v := m.addresponse_status
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearResponseStatus clears the value of the "response_status" field.
+func (m *IdempotencyRecordMutation) ClearResponseStatus() {
+ m.response_status = nil
+ m.addresponse_status = nil
+ m.clearedFields[idempotencyrecord.FieldResponseStatus] = struct{}{}
+}
+
+// ResponseStatusCleared returns if the "response_status" field was cleared in this mutation.
+func (m *IdempotencyRecordMutation) ResponseStatusCleared() bool {
+ _, ok := m.clearedFields[idempotencyrecord.FieldResponseStatus]
+ return ok
+}
+
+// ResetResponseStatus resets all changes to the "response_status" field.
+func (m *IdempotencyRecordMutation) ResetResponseStatus() {
+ m.response_status = nil
+ m.addresponse_status = nil
+ delete(m.clearedFields, idempotencyrecord.FieldResponseStatus)
+}
+
+// SetResponseBody sets the "response_body" field.
+func (m *IdempotencyRecordMutation) SetResponseBody(s string) {
+ m.response_body = &s
+}
+
+// ResponseBody returns the value of the "response_body" field in the mutation.
+func (m *IdempotencyRecordMutation) ResponseBody() (r string, exists bool) {
+ v := m.response_body
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldResponseBody returns the old "response_body" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdempotencyRecordMutation) OldResponseBody(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldResponseBody is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldResponseBody requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldResponseBody: %w", err)
+ }
+ return oldValue.ResponseBody, nil
+}
+
+// ClearResponseBody clears the value of the "response_body" field.
+func (m *IdempotencyRecordMutation) ClearResponseBody() {
+ m.response_body = nil
+ m.clearedFields[idempotencyrecord.FieldResponseBody] = struct{}{}
+}
+
+// ResponseBodyCleared returns if the "response_body" field was cleared in this mutation.
+func (m *IdempotencyRecordMutation) ResponseBodyCleared() bool {
+ _, ok := m.clearedFields[idempotencyrecord.FieldResponseBody]
+ return ok
+}
+
+// ResetResponseBody resets all changes to the "response_body" field.
+func (m *IdempotencyRecordMutation) ResetResponseBody() {
+ m.response_body = nil
+ delete(m.clearedFields, idempotencyrecord.FieldResponseBody)
+}
+
+// SetErrorReason sets the "error_reason" field.
+func (m *IdempotencyRecordMutation) SetErrorReason(s string) {
+ m.error_reason = &s
+}
+
+// ErrorReason returns the value of the "error_reason" field in the mutation.
+func (m *IdempotencyRecordMutation) ErrorReason() (r string, exists bool) {
+ v := m.error_reason
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldErrorReason returns the old "error_reason" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdempotencyRecordMutation) OldErrorReason(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldErrorReason is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldErrorReason requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldErrorReason: %w", err)
+ }
+ return oldValue.ErrorReason, nil
+}
+
+// ClearErrorReason clears the value of the "error_reason" field.
+func (m *IdempotencyRecordMutation) ClearErrorReason() {
+ m.error_reason = nil
+ m.clearedFields[idempotencyrecord.FieldErrorReason] = struct{}{}
+}
+
+// ErrorReasonCleared returns if the "error_reason" field was cleared in this mutation.
+func (m *IdempotencyRecordMutation) ErrorReasonCleared() bool {
+ _, ok := m.clearedFields[idempotencyrecord.FieldErrorReason]
+ return ok
+}
+
+// ResetErrorReason resets all changes to the "error_reason" field.
+func (m *IdempotencyRecordMutation) ResetErrorReason() {
+ m.error_reason = nil
+ delete(m.clearedFields, idempotencyrecord.FieldErrorReason)
+}
+
+// SetLockedUntil sets the "locked_until" field.
+func (m *IdempotencyRecordMutation) SetLockedUntil(t time.Time) {
+ m.locked_until = &t
+}
+
+// LockedUntil returns the value of the "locked_until" field in the mutation.
+func (m *IdempotencyRecordMutation) LockedUntil() (r time.Time, exists bool) {
+ v := m.locked_until
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldLockedUntil returns the old "locked_until" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdempotencyRecordMutation) OldLockedUntil(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldLockedUntil is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldLockedUntil requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldLockedUntil: %w", err)
+ }
+ return oldValue.LockedUntil, nil
+}
+
+// ClearLockedUntil clears the value of the "locked_until" field.
+func (m *IdempotencyRecordMutation) ClearLockedUntil() {
+ m.locked_until = nil
+ m.clearedFields[idempotencyrecord.FieldLockedUntil] = struct{}{}
+}
+
+// LockedUntilCleared returns if the "locked_until" field was cleared in this mutation.
+func (m *IdempotencyRecordMutation) LockedUntilCleared() bool {
+ _, ok := m.clearedFields[idempotencyrecord.FieldLockedUntil]
+ return ok
+}
+
+// ResetLockedUntil resets all changes to the "locked_until" field.
+func (m *IdempotencyRecordMutation) ResetLockedUntil() {
+ m.locked_until = nil
+ delete(m.clearedFields, idempotencyrecord.FieldLockedUntil)
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (m *IdempotencyRecordMutation) SetExpiresAt(t time.Time) {
+ m.expires_at = &t
+}
+
+// ExpiresAt returns the value of the "expires_at" field in the mutation.
+func (m *IdempotencyRecordMutation) ExpiresAt() (r time.Time, exists bool) {
+ v := m.expires_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldExpiresAt returns the old "expires_at" field's value of the IdempotencyRecord entity.
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdempotencyRecordMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldExpiresAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err)
+ }
+ return oldValue.ExpiresAt, nil
+}
+
+// ResetExpiresAt resets all changes to the "expires_at" field.
+func (m *IdempotencyRecordMutation) ResetExpiresAt() {
+ m.expires_at = nil
+}
+
+// Where appends a list predicates to the IdempotencyRecordMutation builder.
+func (m *IdempotencyRecordMutation) Where(ps ...predicate.IdempotencyRecord) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the IdempotencyRecordMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *IdempotencyRecordMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.IdempotencyRecord, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *IdempotencyRecordMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *IdempotencyRecordMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (IdempotencyRecord).
+func (m *IdempotencyRecordMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *IdempotencyRecordMutation) Fields() []string {
+ fields := make([]string, 0, 11)
+ if m.created_at != nil {
+ fields = append(fields, idempotencyrecord.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, idempotencyrecord.FieldUpdatedAt)
+ }
+ if m.scope != nil {
+ fields = append(fields, idempotencyrecord.FieldScope)
+ }
+ if m.idempotency_key_hash != nil {
+ fields = append(fields, idempotencyrecord.FieldIdempotencyKeyHash)
+ }
+ if m.request_fingerprint != nil {
+ fields = append(fields, idempotencyrecord.FieldRequestFingerprint)
+ }
+ if m.status != nil {
+ fields = append(fields, idempotencyrecord.FieldStatus)
+ }
+ if m.response_status != nil {
+ fields = append(fields, idempotencyrecord.FieldResponseStatus)
+ }
+ if m.response_body != nil {
+ fields = append(fields, idempotencyrecord.FieldResponseBody)
+ }
+ if m.error_reason != nil {
+ fields = append(fields, idempotencyrecord.FieldErrorReason)
+ }
+ if m.locked_until != nil {
+ fields = append(fields, idempotencyrecord.FieldLockedUntil)
+ }
+ if m.expires_at != nil {
+ fields = append(fields, idempotencyrecord.FieldExpiresAt)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *IdempotencyRecordMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case idempotencyrecord.FieldCreatedAt:
+ return m.CreatedAt()
+ case idempotencyrecord.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case idempotencyrecord.FieldScope:
+ return m.Scope()
+ case idempotencyrecord.FieldIdempotencyKeyHash:
+ return m.IdempotencyKeyHash()
+ case idempotencyrecord.FieldRequestFingerprint:
+ return m.RequestFingerprint()
+ case idempotencyrecord.FieldStatus:
+ return m.Status()
+ case idempotencyrecord.FieldResponseStatus:
+ return m.ResponseStatus()
+ case idempotencyrecord.FieldResponseBody:
+ return m.ResponseBody()
+ case idempotencyrecord.FieldErrorReason:
+ return m.ErrorReason()
+ case idempotencyrecord.FieldLockedUntil:
+ return m.LockedUntil()
+ case idempotencyrecord.FieldExpiresAt:
+ return m.ExpiresAt()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *IdempotencyRecordMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case idempotencyrecord.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case idempotencyrecord.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case idempotencyrecord.FieldScope:
+ return m.OldScope(ctx)
+ case idempotencyrecord.FieldIdempotencyKeyHash:
+ return m.OldIdempotencyKeyHash(ctx)
+ case idempotencyrecord.FieldRequestFingerprint:
+ return m.OldRequestFingerprint(ctx)
+ case idempotencyrecord.FieldStatus:
+ return m.OldStatus(ctx)
+ case idempotencyrecord.FieldResponseStatus:
+ return m.OldResponseStatus(ctx)
+ case idempotencyrecord.FieldResponseBody:
+ return m.OldResponseBody(ctx)
+ case idempotencyrecord.FieldErrorReason:
+ return m.OldErrorReason(ctx)
+ case idempotencyrecord.FieldLockedUntil:
+ return m.OldLockedUntil(ctx)
+ case idempotencyrecord.FieldExpiresAt:
+ return m.OldExpiresAt(ctx)
+ }
+ return nil, fmt.Errorf("unknown IdempotencyRecord field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *IdempotencyRecordMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case idempotencyrecord.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case idempotencyrecord.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case idempotencyrecord.FieldScope:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetScope(v)
+ return nil
+ case idempotencyrecord.FieldIdempotencyKeyHash:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIdempotencyKeyHash(v)
+ return nil
+ case idempotencyrecord.FieldRequestFingerprint:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRequestFingerprint(v)
+ return nil
+ case idempotencyrecord.FieldStatus:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetStatus(v)
+ return nil
+ case idempotencyrecord.FieldResponseStatus:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetResponseStatus(v)
+ return nil
+ case idempotencyrecord.FieldResponseBody:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetResponseBody(v)
+ return nil
+ case idempotencyrecord.FieldErrorReason:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetErrorReason(v)
+ return nil
+ case idempotencyrecord.FieldLockedUntil:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetLockedUntil(v)
+ return nil
+ case idempotencyrecord.FieldExpiresAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetExpiresAt(v)
+ return nil
+ }
+ return fmt.Errorf("unknown IdempotencyRecord field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *IdempotencyRecordMutation) AddedFields() []string {
+ var fields []string
+ if m.addresponse_status != nil {
+ fields = append(fields, idempotencyrecord.FieldResponseStatus)
+ }
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *IdempotencyRecordMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ case idempotencyrecord.FieldResponseStatus:
+ return m.AddedResponseStatus()
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *IdempotencyRecordMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ case idempotencyrecord.FieldResponseStatus:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddResponseStatus(v)
+ return nil
+ }
+ return fmt.Errorf("unknown IdempotencyRecord numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *IdempotencyRecordMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(idempotencyrecord.FieldResponseStatus) {
+ fields = append(fields, idempotencyrecord.FieldResponseStatus)
+ }
+ if m.FieldCleared(idempotencyrecord.FieldResponseBody) {
+ fields = append(fields, idempotencyrecord.FieldResponseBody)
+ }
+ if m.FieldCleared(idempotencyrecord.FieldErrorReason) {
+ fields = append(fields, idempotencyrecord.FieldErrorReason)
+ }
+ if m.FieldCleared(idempotencyrecord.FieldLockedUntil) {
+ fields = append(fields, idempotencyrecord.FieldLockedUntil)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *IdempotencyRecordMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *IdempotencyRecordMutation) ClearField(name string) error {
+ switch name {
+ case idempotencyrecord.FieldResponseStatus:
+ m.ClearResponseStatus()
+ return nil
+ case idempotencyrecord.FieldResponseBody:
+ m.ClearResponseBody()
+ return nil
+ case idempotencyrecord.FieldErrorReason:
+ m.ClearErrorReason()
+ return nil
+ case idempotencyrecord.FieldLockedUntil:
+ m.ClearLockedUntil()
+ return nil
+ }
+ return fmt.Errorf("unknown IdempotencyRecord nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *IdempotencyRecordMutation) ResetField(name string) error {
+ switch name {
+ case idempotencyrecord.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case idempotencyrecord.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case idempotencyrecord.FieldScope:
+ m.ResetScope()
+ return nil
+ case idempotencyrecord.FieldIdempotencyKeyHash:
+ m.ResetIdempotencyKeyHash()
+ return nil
+ case idempotencyrecord.FieldRequestFingerprint:
+ m.ResetRequestFingerprint()
+ return nil
+ case idempotencyrecord.FieldStatus:
+ m.ResetStatus()
+ return nil
+ case idempotencyrecord.FieldResponseStatus:
+ m.ResetResponseStatus()
+ return nil
+ case idempotencyrecord.FieldResponseBody:
+ m.ResetResponseBody()
+ return nil
+ case idempotencyrecord.FieldErrorReason:
+ m.ResetErrorReason()
+ return nil
+ case idempotencyrecord.FieldLockedUntil:
+ m.ResetLockedUntil()
+ return nil
+ case idempotencyrecord.FieldExpiresAt:
+ m.ResetExpiresAt()
+ return nil
+ }
+ return fmt.Errorf("unknown IdempotencyRecord field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *IdempotencyRecordMutation) AddedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *IdempotencyRecordMutation) AddedIDs(name string) []ent.Value {
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *IdempotencyRecordMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *IdempotencyRecordMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *IdempotencyRecordMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *IdempotencyRecordMutation) EdgeCleared(name string) bool {
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *IdempotencyRecordMutation) ClearEdge(name string) error {
+ return fmt.Errorf("unknown IdempotencyRecord unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *IdempotencyRecordMutation) ResetEdge(name string) error {
+ return fmt.Errorf("unknown IdempotencyRecord edge %s", name)
+}
+
// PromoCodeMutation represents an operation that mutates the PromoCode nodes in the graph.
type PromoCodeMutation struct {
config
@@ -19038,6 +20255,10 @@ type UserMutation struct {
totp_secret_encrypted *string
totp_enabled *bool
totp_enabled_at *time.Time
+ sora_storage_quota_bytes *int64
+ addsora_storage_quota_bytes *int64
+ sora_storage_used_bytes *int64
+ addsora_storage_used_bytes *int64
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -19752,6 +20973,118 @@ func (m *UserMutation) ResetTotpEnabledAt() {
delete(m.clearedFields, user.FieldTotpEnabledAt)
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (m *UserMutation) SetSoraStorageQuotaBytes(i int64) {
+ m.sora_storage_quota_bytes = &i
+ m.addsora_storage_quota_bytes = nil
+}
+
+// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation.
+func (m *UserMutation) SoraStorageQuotaBytes() (r int64, exists bool) {
+ v := m.sora_storage_quota_bytes
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err)
+ }
+ return oldValue.SoraStorageQuotaBytes, nil
+}
+
+// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field.
+func (m *UserMutation) AddSoraStorageQuotaBytes(i int64) {
+ if m.addsora_storage_quota_bytes != nil {
+ *m.addsora_storage_quota_bytes += i
+ } else {
+ m.addsora_storage_quota_bytes = &i
+ }
+}
+
+// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation.
+func (m *UserMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) {
+ v := m.addsora_storage_quota_bytes
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field.
+func (m *UserMutation) ResetSoraStorageQuotaBytes() {
+ m.sora_storage_quota_bytes = nil
+ m.addsora_storage_quota_bytes = nil
+}
+
+// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
+func (m *UserMutation) SetSoraStorageUsedBytes(i int64) {
+ m.sora_storage_used_bytes = &i
+ m.addsora_storage_used_bytes = nil
+}
+
+// SoraStorageUsedBytes returns the value of the "sora_storage_used_bytes" field in the mutation.
+func (m *UserMutation) SoraStorageUsedBytes() (r int64, exists bool) {
+ v := m.sora_storage_used_bytes
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSoraStorageUsedBytes returns the old "sora_storage_used_bytes" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldSoraStorageUsedBytes(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSoraStorageUsedBytes is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSoraStorageUsedBytes requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSoraStorageUsedBytes: %w", err)
+ }
+ return oldValue.SoraStorageUsedBytes, nil
+}
+
+// AddSoraStorageUsedBytes adds i to the "sora_storage_used_bytes" field.
+func (m *UserMutation) AddSoraStorageUsedBytes(i int64) {
+ if m.addsora_storage_used_bytes != nil {
+ *m.addsora_storage_used_bytes += i
+ } else {
+ m.addsora_storage_used_bytes = &i
+ }
+}
+
+// AddedSoraStorageUsedBytes returns the value that was added to the "sora_storage_used_bytes" field in this mutation.
+func (m *UserMutation) AddedSoraStorageUsedBytes() (r int64, exists bool) {
+ v := m.addsora_storage_used_bytes
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetSoraStorageUsedBytes resets all changes to the "sora_storage_used_bytes" field.
+func (m *UserMutation) ResetSoraStorageUsedBytes() {
+ m.sora_storage_used_bytes = nil
+ m.addsora_storage_used_bytes = nil
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -20272,7 +21605,7 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UserMutation) Fields() []string {
- fields := make([]string, 0, 14)
+ fields := make([]string, 0, 16)
if m.created_at != nil {
fields = append(fields, user.FieldCreatedAt)
}
@@ -20315,6 +21648,12 @@ func (m *UserMutation) Fields() []string {
if m.totp_enabled_at != nil {
fields = append(fields, user.FieldTotpEnabledAt)
}
+ if m.sora_storage_quota_bytes != nil {
+ fields = append(fields, user.FieldSoraStorageQuotaBytes)
+ }
+ if m.sora_storage_used_bytes != nil {
+ fields = append(fields, user.FieldSoraStorageUsedBytes)
+ }
return fields
}
@@ -20351,6 +21690,10 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
return m.TotpEnabled()
case user.FieldTotpEnabledAt:
return m.TotpEnabledAt()
+ case user.FieldSoraStorageQuotaBytes:
+ return m.SoraStorageQuotaBytes()
+ case user.FieldSoraStorageUsedBytes:
+ return m.SoraStorageUsedBytes()
}
return nil, false
}
@@ -20388,6 +21731,10 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
return m.OldTotpEnabled(ctx)
case user.FieldTotpEnabledAt:
return m.OldTotpEnabledAt(ctx)
+ case user.FieldSoraStorageQuotaBytes:
+ return m.OldSoraStorageQuotaBytes(ctx)
+ case user.FieldSoraStorageUsedBytes:
+ return m.OldSoraStorageUsedBytes(ctx)
}
return nil, fmt.Errorf("unknown User field %s", name)
}
@@ -20495,6 +21842,20 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
}
m.SetTotpEnabledAt(v)
return nil
+ case user.FieldSoraStorageQuotaBytes:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSoraStorageQuotaBytes(v)
+ return nil
+ case user.FieldSoraStorageUsedBytes:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSoraStorageUsedBytes(v)
+ return nil
}
return fmt.Errorf("unknown User field %s", name)
}
@@ -20509,6 +21870,12 @@ func (m *UserMutation) AddedFields() []string {
if m.addconcurrency != nil {
fields = append(fields, user.FieldConcurrency)
}
+ if m.addsora_storage_quota_bytes != nil {
+ fields = append(fields, user.FieldSoraStorageQuotaBytes)
+ }
+ if m.addsora_storage_used_bytes != nil {
+ fields = append(fields, user.FieldSoraStorageUsedBytes)
+ }
return fields
}
@@ -20521,6 +21888,10 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedBalance()
case user.FieldConcurrency:
return m.AddedConcurrency()
+ case user.FieldSoraStorageQuotaBytes:
+ return m.AddedSoraStorageQuotaBytes()
+ case user.FieldSoraStorageUsedBytes:
+ return m.AddedSoraStorageUsedBytes()
}
return nil, false
}
@@ -20544,6 +21915,20 @@ func (m *UserMutation) AddField(name string, value ent.Value) error {
}
m.AddConcurrency(v)
return nil
+ case user.FieldSoraStorageQuotaBytes:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddSoraStorageQuotaBytes(v)
+ return nil
+ case user.FieldSoraStorageUsedBytes:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddSoraStorageUsedBytes(v)
+ return nil
}
return fmt.Errorf("unknown User numeric field %s", name)
}
@@ -20634,6 +22019,12 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldTotpEnabledAt:
m.ResetTotpEnabledAt()
return nil
+ case user.FieldSoraStorageQuotaBytes:
+ m.ResetSoraStorageQuotaBytes()
+ return nil
+ case user.FieldSoraStorageUsedBytes:
+ m.ResetSoraStorageUsedBytes()
+ return nil
}
return fmt.Errorf("unknown User field %s", name)
}
diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go
index 584b9606..89d933fc 100644
--- a/backend/ent/predicate/predicate.go
+++ b/backend/ent/predicate/predicate.go
@@ -27,6 +27,9 @@ type ErrorPassthroughRule func(*sql.Selector)
// Group is the predicate function for group builders.
type Group func(*sql.Selector)
+// IdempotencyRecord is the predicate function for idempotencyrecord builders.
+type IdempotencyRecord func(*sql.Selector)
+
// PromoCode is the predicate function for promocode builders.
type PromoCode func(*sql.Selector)
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index ff3f8f26..65531aae 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -12,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
+ "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
@@ -209,7 +210,7 @@ func init() {
// account.DefaultSchedulable holds the default value on creation for the schedulable field.
account.DefaultSchedulable = accountDescSchedulable.Default.(bool)
// accountDescSessionWindowStatus is the schema descriptor for session_window_status field.
- accountDescSessionWindowStatus := accountFields[21].Descriptor()
+ accountDescSessionWindowStatus := accountFields[23].Descriptor()
// account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save.
account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error)
accountgroupFields := schema.AccountGroup{}.Fields()
@@ -398,26 +399,65 @@ func init() {
groupDescDefaultValidityDays := groupFields[10].Descriptor()
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
+ // groupDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field.
+ groupDescSoraStorageQuotaBytes := groupFields[18].Descriptor()
+ // group.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field.
+ group.DefaultSoraStorageQuotaBytes = groupDescSoraStorageQuotaBytes.Default.(int64)
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
- groupDescClaudeCodeOnly := groupFields[18].Descriptor()
+ groupDescClaudeCodeOnly := groupFields[19].Descriptor()
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
- groupDescModelRoutingEnabled := groupFields[22].Descriptor()
+ groupDescModelRoutingEnabled := groupFields[23].Descriptor()
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
// groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field.
- groupDescMcpXMLInject := groupFields[23].Descriptor()
+ groupDescMcpXMLInject := groupFields[24].Descriptor()
// group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field.
group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool)
// groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field.
- groupDescSupportedModelScopes := groupFields[24].Descriptor()
+ groupDescSupportedModelScopes := groupFields[25].Descriptor()
// group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field.
group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string)
// groupDescSortOrder is the schema descriptor for sort_order field.
- groupDescSortOrder := groupFields[25].Descriptor()
+ groupDescSortOrder := groupFields[26].Descriptor()
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
+ idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
+ idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
+ _ = idempotencyrecordMixinFields0
+ idempotencyrecordFields := schema.IdempotencyRecord{}.Fields()
+ _ = idempotencyrecordFields
+ // idempotencyrecordDescCreatedAt is the schema descriptor for created_at field.
+ idempotencyrecordDescCreatedAt := idempotencyrecordMixinFields0[0].Descriptor()
+ // idempotencyrecord.DefaultCreatedAt holds the default value on creation for the created_at field.
+ idempotencyrecord.DefaultCreatedAt = idempotencyrecordDescCreatedAt.Default.(func() time.Time)
+ // idempotencyrecordDescUpdatedAt is the schema descriptor for updated_at field.
+ idempotencyrecordDescUpdatedAt := idempotencyrecordMixinFields0[1].Descriptor()
+ // idempotencyrecord.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ idempotencyrecord.DefaultUpdatedAt = idempotencyrecordDescUpdatedAt.Default.(func() time.Time)
+ // idempotencyrecord.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ idempotencyrecord.UpdateDefaultUpdatedAt = idempotencyrecordDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // idempotencyrecordDescScope is the schema descriptor for scope field.
+ idempotencyrecordDescScope := idempotencyrecordFields[0].Descriptor()
+ // idempotencyrecord.ScopeValidator is a validator for the "scope" field. It is called by the builders before save.
+ idempotencyrecord.ScopeValidator = idempotencyrecordDescScope.Validators[0].(func(string) error)
+ // idempotencyrecordDescIdempotencyKeyHash is the schema descriptor for idempotency_key_hash field.
+ idempotencyrecordDescIdempotencyKeyHash := idempotencyrecordFields[1].Descriptor()
+ // idempotencyrecord.IdempotencyKeyHashValidator is a validator for the "idempotency_key_hash" field. It is called by the builders before save.
+ idempotencyrecord.IdempotencyKeyHashValidator = idempotencyrecordDescIdempotencyKeyHash.Validators[0].(func(string) error)
+ // idempotencyrecordDescRequestFingerprint is the schema descriptor for request_fingerprint field.
+ idempotencyrecordDescRequestFingerprint := idempotencyrecordFields[2].Descriptor()
+ // idempotencyrecord.RequestFingerprintValidator is a validator for the "request_fingerprint" field. It is called by the builders before save.
+ idempotencyrecord.RequestFingerprintValidator = idempotencyrecordDescRequestFingerprint.Validators[0].(func(string) error)
+ // idempotencyrecordDescStatus is the schema descriptor for status field.
+ idempotencyrecordDescStatus := idempotencyrecordFields[3].Descriptor()
+ // idempotencyrecord.StatusValidator is a validator for the "status" field. It is called by the builders before save.
+ idempotencyrecord.StatusValidator = idempotencyrecordDescStatus.Validators[0].(func(string) error)
+ // idempotencyrecordDescErrorReason is the schema descriptor for error_reason field.
+ idempotencyrecordDescErrorReason := idempotencyrecordFields[6].Descriptor()
+ // idempotencyrecord.ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save.
+ idempotencyrecord.ErrorReasonValidator = idempotencyrecordDescErrorReason.Validators[0].(func(string) error)
promocodeFields := schema.PromoCode{}.Fields()
_ = promocodeFields
// promocodeDescCode is the schema descriptor for code field.
@@ -918,6 +958,14 @@ func init() {
userDescTotpEnabled := userFields[9].Descriptor()
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
+ // userDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field.
+ userDescSoraStorageQuotaBytes := userFields[11].Descriptor()
+ // user.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field.
+ user.DefaultSoraStorageQuotaBytes = userDescSoraStorageQuotaBytes.Default.(int64)
+ // userDescSoraStorageUsedBytes is the schema descriptor for sora_storage_used_bytes field.
+ userDescSoraStorageUsedBytes := userFields[12].Descriptor()
+ // user.DefaultSoraStorageUsedBytes holds the default value on creation for the sora_storage_used_bytes field.
+ user.DefaultSoraStorageUsedBytes = userDescSoraStorageUsedBytes.Default.(int64)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
_ = userallowedgroupFields
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
diff --git a/backend/ent/schema/account.go b/backend/ent/schema/account.go
index 1cfecc2d..443f9e09 100644
--- a/backend/ent/schema/account.go
+++ b/backend/ent/schema/account.go
@@ -164,6 +164,19 @@ func (Account) Fields() []ent.Field {
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ // temp_unschedulable_until: 临时不可调度状态解除时间
+ // 当命中临时不可调度规则时设置,在此时间前调度器应跳过该账号
+ field.Time("temp_unschedulable_until").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+
+ // temp_unschedulable_reason: 临时不可调度原因,便于排障审计
+ field.String("temp_unschedulable_reason").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+
// session_window_*: 会话窗口相关字段
// 用于管理某些需要会话时间窗口的 API(如 Claude Pro)
field.Time("session_window_start").
@@ -213,6 +226,9 @@ func (Account) Indexes() []ent.Index {
index.Fields("rate_limited_at"), // 筛选速率限制账户
index.Fields("rate_limit_reset_at"), // 筛选速率限制解除时间
index.Fields("overload_until"), // 筛选过载账户
- index.Fields("deleted_at"), // 软删除查询优化
+ // 调度热路径复合索引(线上由 SQL 迁移创建部分索引,schema 仅用于模型可读性对齐)
+ index.Fields("platform", "priority"),
+ index.Fields("priority", "status"),
+ index.Fields("deleted_at"), // 软删除查询优化
}
}
diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go
index fddf23ce..3fcf8674 100644
--- a/backend/ent/schema/group.go
+++ b/backend/ent/schema/group.go
@@ -105,6 +105,10 @@ func (Group) Fields() []ent.Field {
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
+ // Sora 存储配额
+ field.Int64("sora_storage_quota_bytes").
+ Default(0),
+
// Claude Code 客户端限制 (added by migration 029)
field.Bool("claude_code_only").
Default(false).
diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go
index ffcae840..dcca1a0a 100644
--- a/backend/ent/schema/usage_log.go
+++ b/backend/ent/schema/usage_log.go
@@ -179,5 +179,7 @@ func (UsageLog) Indexes() []ent.Index {
// 复合索引用于时间范围查询
index.Fields("user_id", "created_at"),
index.Fields("api_key_id", "created_at"),
+ // 分组维度时间范围查询(线上由 SQL 迁移创建 group_id IS NOT NULL 的部分索引)
+ index.Fields("group_id", "created_at"),
}
}
diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go
index d443ef45..0a3b5d9e 100644
--- a/backend/ent/schema/user.go
+++ b/backend/ent/schema/user.go
@@ -72,6 +72,12 @@ func (User) Fields() []ent.Field {
field.Time("totp_enabled_at").
Optional().
Nillable(),
+
+ // Sora 存储配额
+ field.Int64("sora_storage_quota_bytes").
+ Default(0),
+ field.Int64("sora_storage_used_bytes").
+ Default(0),
}
}
diff --git a/backend/ent/schema/user_subscription.go b/backend/ent/schema/user_subscription.go
index fa13612b..a81850b1 100644
--- a/backend/ent/schema/user_subscription.go
+++ b/backend/ent/schema/user_subscription.go
@@ -108,6 +108,8 @@ func (UserSubscription) Indexes() []ent.Index {
index.Fields("group_id"),
index.Fields("status"),
index.Fields("expires_at"),
+ // 活跃订阅查询复合索引(线上由 SQL 迁移创建部分索引,schema 仅用于模型可读性对齐)
+ index.Fields("user_id", "status", "expires_at"),
index.Fields("assigned_by"),
// 唯一约束通过部分索引实现(WHERE deleted_at IS NULL),支持软删除后重新订阅
// 见迁移文件 016_soft_delete_partial_unique_indexes.sql
diff --git a/backend/ent/tx.go b/backend/ent/tx.go
index 4fbe9bb4..cd3b2296 100644
--- a/backend/ent/tx.go
+++ b/backend/ent/tx.go
@@ -28,6 +28,8 @@ type Tx struct {
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
+ // IdempotencyRecord is the client for interacting with the IdempotencyRecord builders.
+ IdempotencyRecord *IdempotencyRecordClient
// PromoCode is the client for interacting with the PromoCode builders.
PromoCode *PromoCodeClient
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
@@ -192,6 +194,7 @@ func (tx *Tx) init() {
tx.AnnouncementRead = NewAnnouncementReadClient(tx.config)
tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config)
tx.Group = NewGroupClient(tx.config)
+ tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config)
tx.PromoCode = NewPromoCodeClient(tx.config)
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
tx.Proxy = NewProxyClient(tx.config)
diff --git a/backend/ent/user.go b/backend/ent/user.go
index 2435aa1b..b3f933f6 100644
--- a/backend/ent/user.go
+++ b/backend/ent/user.go
@@ -45,6 +45,10 @@ type User struct {
TotpEnabled bool `json:"totp_enabled,omitempty"`
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
+ // SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
+ SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
+ // SoraStorageUsedBytes holds the value of the "sora_storage_used_bytes" field.
+ SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set.
Edges UserEdges `json:"edges"`
@@ -177,7 +181,7 @@ func (*User) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullBool)
case user.FieldBalance:
values[i] = new(sql.NullFloat64)
- case user.FieldID, user.FieldConcurrency:
+ case user.FieldID, user.FieldConcurrency, user.FieldSoraStorageQuotaBytes, user.FieldSoraStorageUsedBytes:
values[i] = new(sql.NullInt64)
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
values[i] = new(sql.NullString)
@@ -291,6 +295,18 @@ func (_m *User) assignValues(columns []string, values []any) error {
_m.TotpEnabledAt = new(time.Time)
*_m.TotpEnabledAt = value.Time
}
+ case user.FieldSoraStorageQuotaBytes:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i])
+ } else if value.Valid {
+ _m.SoraStorageQuotaBytes = value.Int64
+ }
+ case user.FieldSoraStorageUsedBytes:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field sora_storage_used_bytes", values[i])
+ } else if value.Valid {
+ _m.SoraStorageUsedBytes = value.Int64
+ }
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -424,6 +440,12 @@ func (_m *User) String() string {
builder.WriteString("totp_enabled_at=")
builder.WriteString(v.Format(time.ANSIC))
}
+ builder.WriteString(", ")
+ builder.WriteString("sora_storage_quota_bytes=")
+ builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes))
+ builder.WriteString(", ")
+ builder.WriteString("sora_storage_used_bytes=")
+ builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageUsedBytes))
builder.WriteByte(')')
return builder.String()
}
diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go
index ae9418ff..155b9160 100644
--- a/backend/ent/user/user.go
+++ b/backend/ent/user/user.go
@@ -43,6 +43,10 @@ const (
FieldTotpEnabled = "totp_enabled"
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
FieldTotpEnabledAt = "totp_enabled_at"
+ // FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database.
+ FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes"
+ // FieldSoraStorageUsedBytes holds the string denoting the sora_storage_used_bytes field in the database.
+ FieldSoraStorageUsedBytes = "sora_storage_used_bytes"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -152,6 +156,8 @@ var Columns = []string{
FieldTotpSecretEncrypted,
FieldTotpEnabled,
FieldTotpEnabledAt,
+ FieldSoraStorageQuotaBytes,
+ FieldSoraStorageUsedBytes,
}
var (
@@ -208,6 +214,10 @@ var (
DefaultNotes string
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
DefaultTotpEnabled bool
+ // DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field.
+ DefaultSoraStorageQuotaBytes int64
+ // DefaultSoraStorageUsedBytes holds the default value on creation for the "sora_storage_used_bytes" field.
+ DefaultSoraStorageUsedBytes int64
)
// OrderOption defines the ordering options for the User queries.
@@ -288,6 +298,16 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
}
+// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field.
+func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc()
+}
+
+// BySoraStorageUsedBytes orders the results by the sora_storage_used_bytes field.
+func BySoraStorageUsedBytes(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSoraStorageUsedBytes, opts...).ToFunc()
+}
+
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go
index 1de61037..e26afcf3 100644
--- a/backend/ent/user/where.go
+++ b/backend/ent/user/where.go
@@ -125,6 +125,16 @@ func TotpEnabledAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
+// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ.
+func SoraStorageQuotaBytes(v int64) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageUsedBytes applies equality check predicate on the "sora_storage_used_bytes" field. It's identical to SoraStorageUsedBytesEQ.
+func SoraStorageUsedBytes(v int64) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v))
+}
+
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
@@ -860,6 +870,86 @@ func TotpEnabledAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
}
+// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesEQ(v int64) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesNEQ(v int64) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesIn(vs ...int64) predicate.User {
+ return predicate.User(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...))
+}
+
+// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...))
+}
+
+// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesGT(v int64) predicate.User {
+ return predicate.User(sql.FieldGT(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesGTE(v int64) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesLT(v int64) predicate.User {
+ return predicate.User(sql.FieldLT(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field.
+func SoraStorageQuotaBytesLTE(v int64) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldSoraStorageQuotaBytes, v))
+}
+
+// SoraStorageUsedBytesEQ applies the EQ predicate on the "sora_storage_used_bytes" field.
+func SoraStorageUsedBytesEQ(v int64) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v))
+}
+
+// SoraStorageUsedBytesNEQ applies the NEQ predicate on the "sora_storage_used_bytes" field.
+func SoraStorageUsedBytesNEQ(v int64) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldSoraStorageUsedBytes, v))
+}
+
+// SoraStorageUsedBytesIn applies the In predicate on the "sora_storage_used_bytes" field.
+func SoraStorageUsedBytesIn(vs ...int64) predicate.User {
+ return predicate.User(sql.FieldIn(FieldSoraStorageUsedBytes, vs...))
+}
+
+// SoraStorageUsedBytesNotIn applies the NotIn predicate on the "sora_storage_used_bytes" field.
+func SoraStorageUsedBytesNotIn(vs ...int64) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldSoraStorageUsedBytes, vs...))
+}
+
+// SoraStorageUsedBytesGT applies the GT predicate on the "sora_storage_used_bytes" field.
+func SoraStorageUsedBytesGT(v int64) predicate.User {
+ return predicate.User(sql.FieldGT(FieldSoraStorageUsedBytes, v))
+}
+
+// SoraStorageUsedBytesGTE applies the GTE predicate on the "sora_storage_used_bytes" field.
+func SoraStorageUsedBytesGTE(v int64) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldSoraStorageUsedBytes, v))
+}
+
+// SoraStorageUsedBytesLT applies the LT predicate on the "sora_storage_used_bytes" field.
+func SoraStorageUsedBytesLT(v int64) predicate.User {
+ return predicate.User(sql.FieldLT(FieldSoraStorageUsedBytes, v))
+}
+
+// SoraStorageUsedBytesLTE applies the LTE predicate on the "sora_storage_used_bytes" field.
+func SoraStorageUsedBytesLTE(v int64) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldSoraStorageUsedBytes, v))
+}
+
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.User {
return predicate.User(func(s *sql.Selector) {
diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go
index f862a580..df0c6bcc 100644
--- a/backend/ent/user_create.go
+++ b/backend/ent/user_create.go
@@ -210,6 +210,34 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
return _c
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (_c *UserCreate) SetSoraStorageQuotaBytes(v int64) *UserCreate {
+ _c.mutation.SetSoraStorageQuotaBytes(v)
+ return _c
+}
+
+// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
+func (_c *UserCreate) SetNillableSoraStorageQuotaBytes(v *int64) *UserCreate {
+ if v != nil {
+ _c.SetSoraStorageQuotaBytes(*v)
+ }
+ return _c
+}
+
+// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
+func (_c *UserCreate) SetSoraStorageUsedBytes(v int64) *UserCreate {
+ _c.mutation.SetSoraStorageUsedBytes(v)
+ return _c
+}
+
+// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
+func (_c *UserCreate) SetNillableSoraStorageUsedBytes(v *int64) *UserCreate {
+ if v != nil {
+ _c.SetSoraStorageUsedBytes(*v)
+ }
+ return _c
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -424,6 +452,14 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultTotpEnabled
_c.mutation.SetTotpEnabled(v)
}
+ if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
+ v := user.DefaultSoraStorageQuotaBytes
+ _c.mutation.SetSoraStorageQuotaBytes(v)
+ }
+ if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok {
+ v := user.DefaultSoraStorageUsedBytes
+ _c.mutation.SetSoraStorageUsedBytes(v)
+ }
return nil
}
@@ -487,6 +523,12 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.TotpEnabled(); !ok {
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
}
+ if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
+ return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "User.sora_storage_quota_bytes"`)}
+ }
+ if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok {
+ return &ValidationError{Name: "sora_storage_used_bytes", err: errors.New(`ent: missing required field "User.sora_storage_used_bytes"`)}
+ }
return nil
}
@@ -570,6 +612,14 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
_node.TotpEnabledAt = &value
}
+ if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok {
+ _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ _node.SoraStorageQuotaBytes = value
+ }
+ if value, ok := _c.mutation.SoraStorageUsedBytes(); ok {
+ _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
+ _node.SoraStorageUsedBytes = value
+ }
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -956,6 +1006,42 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
return u
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (u *UserUpsert) SetSoraStorageQuotaBytes(v int64) *UserUpsert {
+ u.Set(user.FieldSoraStorageQuotaBytes, v)
+ return u
+}
+
+// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
+func (u *UserUpsert) UpdateSoraStorageQuotaBytes() *UserUpsert {
+ u.SetExcluded(user.FieldSoraStorageQuotaBytes)
+ return u
+}
+
+// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
+func (u *UserUpsert) AddSoraStorageQuotaBytes(v int64) *UserUpsert {
+ u.Add(user.FieldSoraStorageQuotaBytes, v)
+ return u
+}
+
+// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
+func (u *UserUpsert) SetSoraStorageUsedBytes(v int64) *UserUpsert {
+ u.Set(user.FieldSoraStorageUsedBytes, v)
+ return u
+}
+
+// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
+func (u *UserUpsert) UpdateSoraStorageUsedBytes() *UserUpsert {
+ u.SetExcluded(user.FieldSoraStorageUsedBytes)
+ return u
+}
+
+// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
+func (u *UserUpsert) AddSoraStorageUsedBytes(v int64) *UserUpsert {
+ u.Add(user.FieldSoraStorageUsedBytes, v)
+ return u
+}
+
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -1218,6 +1304,48 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
})
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (u *UserUpsertOne) SetSoraStorageQuotaBytes(v int64) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetSoraStorageQuotaBytes(v)
+ })
+}
+
+// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
+func (u *UserUpsertOne) AddSoraStorageQuotaBytes(v int64) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.AddSoraStorageQuotaBytes(v)
+ })
+}
+
+// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateSoraStorageQuotaBytes() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateSoraStorageQuotaBytes()
+ })
+}
+
+// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
+func (u *UserUpsertOne) SetSoraStorageUsedBytes(v int64) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetSoraStorageUsedBytes(v)
+ })
+}
+
+// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
+func (u *UserUpsertOne) AddSoraStorageUsedBytes(v int64) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.AddSoraStorageUsedBytes(v)
+ })
+}
+
+// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateSoraStorageUsedBytes() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateSoraStorageUsedBytes()
+ })
+}
+
// Exec executes the query.
func (u *UserUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -1646,6 +1774,48 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
})
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (u *UserUpsertBulk) SetSoraStorageQuotaBytes(v int64) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetSoraStorageQuotaBytes(v)
+ })
+}
+
+// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
+func (u *UserUpsertBulk) AddSoraStorageQuotaBytes(v int64) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.AddSoraStorageQuotaBytes(v)
+ })
+}
+
+// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateSoraStorageQuotaBytes() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateSoraStorageQuotaBytes()
+ })
+}
+
+// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
+func (u *UserUpsertBulk) SetSoraStorageUsedBytes(v int64) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetSoraStorageUsedBytes(v)
+ })
+}
+
+// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
+func (u *UserUpsertBulk) AddSoraStorageUsedBytes(v int64) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.AddSoraStorageUsedBytes(v)
+ })
+}
+
+// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateSoraStorageUsedBytes() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateSoraStorageUsedBytes()
+ })
+}
+
// Exec executes the query.
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go
index 80222c92..f71f0cad 100644
--- a/backend/ent/user_update.go
+++ b/backend/ent/user_update.go
@@ -242,6 +242,48 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
return _u
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (_u *UserUpdate) SetSoraStorageQuotaBytes(v int64) *UserUpdate {
+ _u.mutation.ResetSoraStorageQuotaBytes()
+ _u.mutation.SetSoraStorageQuotaBytes(v)
+ return _u
+}
+
+// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdate {
+ if v != nil {
+ _u.SetSoraStorageQuotaBytes(*v)
+ }
+ return _u
+}
+
+// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
+func (_u *UserUpdate) AddSoraStorageQuotaBytes(v int64) *UserUpdate {
+ _u.mutation.AddSoraStorageQuotaBytes(v)
+ return _u
+}
+
+// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
+func (_u *UserUpdate) SetSoraStorageUsedBytes(v int64) *UserUpdate {
+ _u.mutation.ResetSoraStorageUsedBytes()
+ _u.mutation.SetSoraStorageUsedBytes(v)
+ return _u
+}
+
+// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdate {
+ if v != nil {
+ _u.SetSoraStorageUsedBytes(*v)
+ }
+ return _u
+}
+
+// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field.
+func (_u *UserUpdate) AddSoraStorageUsedBytes(v int64) *UserUpdate {
+ _u.mutation.AddSoraStorageUsedBytes(v)
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -709,6 +751,18 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
+ if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
+ _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
+ _spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.SoraStorageUsedBytes(); ok {
+ _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok {
+ _spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1352,6 +1406,48 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
return _u
}
+// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
+func (_u *UserUpdateOne) SetSoraStorageQuotaBytes(v int64) *UserUpdateOne {
+ _u.mutation.ResetSoraStorageQuotaBytes()
+ _u.mutation.SetSoraStorageQuotaBytes(v)
+ return _u
+}
+
+// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdateOne {
+ if v != nil {
+ _u.SetSoraStorageQuotaBytes(*v)
+ }
+ return _u
+}
+
+// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
+func (_u *UserUpdateOne) AddSoraStorageQuotaBytes(v int64) *UserUpdateOne {
+ _u.mutation.AddSoraStorageQuotaBytes(v)
+ return _u
+}
+
+// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
+func (_u *UserUpdateOne) SetSoraStorageUsedBytes(v int64) *UserUpdateOne {
+ _u.mutation.ResetSoraStorageUsedBytes()
+ _u.mutation.SetSoraStorageUsedBytes(v)
+ return _u
+}
+
+// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdateOne {
+ if v != nil {
+ _u.SetSoraStorageUsedBytes(*v)
+ }
+ return _u
+}
+
+// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field.
+func (_u *UserUpdateOne) AddSoraStorageUsedBytes(v int64) *UserUpdateOne {
+ _u.mutation.AddSoraStorageUsedBytes(v)
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1849,6 +1945,18 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
+ if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
+ _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
+ _spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.SoraStorageUsedBytes(); ok {
+ _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok {
+ _spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
diff --git a/backend/go.mod b/backend/go.mod
index 0adddadf..a34c9fff 100644
--- a/backend/go.mod
+++ b/backend/go.mod
@@ -7,7 +7,11 @@ require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/DouDOU-start/go-sora2api v1.1.0
github.com/alitto/pond/v2 v2.6.2
+ github.com/aws/aws-sdk-go-v2/config v1.32.10
+ github.com/aws/aws-sdk-go-v2/credentials v1.19.10
+ github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2
github.com/cespare/xxhash/v2 v2.3.0
+ github.com/coder/websocket v1.8.14
github.com/dgraph-io/ristretto v0.2.0
github.com/gin-gonic/gin v1.9.1
github.com/golang-jwt/jwt/v5 v5.2.2
@@ -34,6 +38,8 @@ require (
golang.org/x/net v0.49.0
golang.org/x/sync v0.19.0
golang.org/x/term v0.40.0
+ google.golang.org/grpc v1.75.1
+ google.golang.org/protobuf v1.36.10
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
modernc.org/sqlite v1.44.3
@@ -47,6 +53,22 @@ require (
github.com/agext/levenshtein v1.2.3 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
+ github.com/aws/aws-sdk-go-v2 v1.41.2 // indirect
+ github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect
+ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18 // indirect
+ github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5 // indirect
+ github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10 // indirect
+ github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18 // indirect
+ github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18 // indirect
+ github.com/aws/aws-sdk-go-v2/service/signin v1.0.6 // indirect
+ github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect
+ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
+ github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
+ github.com/aws/smithy-go v1.24.1 // indirect
github.com/bdandy/go-errors v1.2.2 // indirect
github.com/bdandy/go-socks4 v1.2.3 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
@@ -87,6 +109,7 @@ require (
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
+ github.com/google/subcommands v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
@@ -146,7 +169,6 @@ require (
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
go.opentelemetry.io/otel v1.37.0 // indirect
go.opentelemetry.io/otel/metric v1.37.0 // indirect
- go.opentelemetry.io/otel/sdk v1.37.0 // indirect
go.opentelemetry.io/otel/trace v1.37.0 // indirect
go.uber.org/atomic v1.10.0 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect
@@ -156,6 +178,8 @@ require (
golang.org/x/mod v0.32.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
+ golang.org/x/tools v0.41.0 // indirect
+ google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
diff --git a/backend/go.sum b/backend/go.sum
index efe6c145..32e389a7 100644
--- a/backend/go.sum
+++ b/backend/go.sum
@@ -22,6 +22,44 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
+github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls=
+github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4=
+github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q=
+github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c=
+github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI=
+github.com/aws/aws-sdk-go-v2/config v1.32.10/go.mod h1:2rUIOnA2JaiqYmSKYmRJlcMWy6qTj1vuRFscppSBMcw=
+github.com/aws/aws-sdk-go-v2/credentials v1.19.10 h1:EEhmEUFCE1Yhl7vDhNOI5OCL/iKMdkkYFTRpZXNw7m8=
+github.com/aws/aws-sdk-go-v2/credentials v1.19.10/go.mod h1:RnnlFCAlxQCkN2Q379B67USkBMu1PipEEiibzYN5UTE=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 h1:Ii4s+Sq3yDfaMLpjrJsqD6SmG/Wq/P5L/hw2qa78UAY=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18/go.mod h1:6x81qnY++ovptLE6nWQeWrpXxbnlIex+4H4eYYGcqfc=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 h1:F43zk1vemYIqPAwhjTjYIz0irU2EY7sOb/F5eJ3HuyM=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18/go.mod h1:w1jdlZXrGKaJcNoL+Nnrj+k5wlpGXqnNrKoP22HvAug=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 h1:xCeWVjj0ki0l3nruoyP2slHsGArMxeiiaoPN5QZH6YQ=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18/go.mod h1:r/eLGuGCBw6l36ZRWiw6PaZwPXb6YOj+i/7MizNl5/k=
+github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk=
+github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc=
+github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18 h1:eZioDaZGJ0tMM4gzmkNIO2aAoQd+je7Ug7TkvAzlmkU=
+github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18/go.mod h1:CCXwUKAJdoWr6/NcxZ+zsiPr6oH/Q5aTooRGYieAyj4=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5 h1:CeY9LUdur+Dxoeldqoun6y4WtJ3RQtzk0JMP2gfUay0=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5/go.mod h1:AZLZf2fMaahW5s/wMRciu1sYbdsikT/UHwbUjOdEVTc=
+github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10 h1:fJvQ5mIBVfKtiyx0AHY6HeWcRX5LGANLpq8SVR+Uazs=
+github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10/go.mod h1:Kzm5e6OmNH8VMkgK9t+ry5jEih4Y8whqs+1hrkxim1I=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18 h1:LTRCYFlnnKFlKsyIQxKhJuDuA3ZkrDQMRYm6rXiHlLY=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18/go.mod h1:XhwkgGG6bHSd00nO/mexWTcTjgd6PjuvWQMqSn2UaEk=
+github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18 h1:/A/xDuZAVD2BpsS2fftFRo/NoEKQJ8YTnJDEHBy2Gtg=
+github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18/go.mod h1:hWe9b4f+djUQGmyiGEeOnZv69dtMSgpDRIvNMvuvzvY=
+github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2 h1:M1A9AjcFwlxTLuf0Faj88L8Iqw0n/AJHjpZTQzMMsSc=
+github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2/go.mod h1:KsdTV6Q9WKUZm2mNJnUFmIoXfZux91M3sr/a4REX8e0=
+github.com/aws/aws-sdk-go-v2/service/signin v1.0.6 h1:MzORe+J94I+hYu2a6XmV5yC9huoTv8NRcCrUNedDypQ=
+github.com/aws/aws-sdk-go-v2/service/signin v1.0.6/go.mod h1:hXzcHLARD7GeWnifd8j9RWqtfIgxj4/cAtIVIK7hg8g=
+github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 h1:7oGD8KPfBOJGXiCoRKrrrQkbvCp8N++u36hrLMPey6o=
+github.com/aws/aws-sdk-go-v2/service/sso v1.30.11/go.mod h1:0DO9B5EUJQlIDif+XJRWCljZRKsAFKh3gpFz7UnDtOo=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 h1:edCcNp9eGIUDUCrzoCu1jWAXLGFIizeqkdkKgRlJwWc=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU=
+github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c=
+github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
+github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0=
+github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
@@ -56,6 +94,12 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
+github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
+github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
+github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
+github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
+github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
+github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
@@ -127,6 +171,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
+github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
+github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
@@ -136,6 +182,8 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
+github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
+github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
@@ -190,6 +238,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
+github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@@ -223,6 +273,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
+github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
+github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@@ -274,6 +326,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
+github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
+github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
@@ -344,6 +398,8 @@ go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/Wgbsd
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
+go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc=
+go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps=
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0=
@@ -399,6 +455,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
+gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk=
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index de3251b6..4f6fea37 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -364,6 +364,8 @@ type GatewayConfig struct {
// OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头
// 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
+ // OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP)
+ OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数
@@ -450,6 +452,101 @@ type GatewayConfig struct {
ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"`
}
+// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。
+// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。
+type GatewayOpenAIWSConfig struct {
+ // ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为)
+ ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
+ // IngressModeDefault: ingress 默认模式(off/shared/dedicated)
+ IngressModeDefault string `mapstructure:"ingress_mode_default"`
+ // Enabled: 全局总开关(默认 true)
+ Enabled bool `mapstructure:"enabled"`
+ // OAuthEnabled: 是否允许 OpenAI OAuth 账号使用 WS
+ OAuthEnabled bool `mapstructure:"oauth_enabled"`
+ // APIKeyEnabled: 是否允许 OpenAI API Key 账号使用 WS
+ APIKeyEnabled bool `mapstructure:"apikey_enabled"`
+ // ForceHTTP: 全局强制 HTTP(用于紧急回滚)
+ ForceHTTP bool `mapstructure:"force_http"`
+ // AllowStoreRecovery: 允许在 WSv2 下按策略恢复 store=true(默认 false)
+ AllowStoreRecovery bool `mapstructure:"allow_store_recovery"`
+ // IngressPreviousResponseRecoveryEnabled: ingress 模式收到 previous_response_not_found 时,是否允许自动去掉 previous_response_id 重试一次(默认 true)
+ IngressPreviousResponseRecoveryEnabled bool `mapstructure:"ingress_previous_response_recovery_enabled"`
+ // StoreDisabledConnMode: store=false 且无可复用会话连接时的建连策略(strict/adaptive/off)
+ // - strict: 强制新建连接(隔离优先)
+ // - adaptive: 仅在高风险失败后强制新建连接(性能与隔离折中)
+ // - off: 不强制新建连接(复用优先)
+ StoreDisabledConnMode string `mapstructure:"store_disabled_conn_mode"`
+ // StoreDisabledForceNewConn: store=false 且无可复用粘连连接时是否强制新建连接(默认 true,保障会话隔离)
+ // 兼容旧配置;当 StoreDisabledConnMode 为空时才生效。
+ StoreDisabledForceNewConn bool `mapstructure:"store_disabled_force_new_conn"`
+ // PrewarmGenerateEnabled: 是否启用 WSv2 generate=false 预热(默认 false)
+ PrewarmGenerateEnabled bool `mapstructure:"prewarm_generate_enabled"`
+
+ // Feature 开关:v2 优先于 v1
+ ResponsesWebsockets bool `mapstructure:"responses_websockets"`
+ ResponsesWebsocketsV2 bool `mapstructure:"responses_websockets_v2"`
+
+ // 连接池参数
+ MaxConnsPerAccount int `mapstructure:"max_conns_per_account"`
+ MinIdlePerAccount int `mapstructure:"min_idle_per_account"`
+ MaxIdlePerAccount int `mapstructure:"max_idle_per_account"`
+ // DynamicMaxConnsByAccountConcurrencyEnabled: 是否按账号并发动态计算连接池上限
+ DynamicMaxConnsByAccountConcurrencyEnabled bool `mapstructure:"dynamic_max_conns_by_account_concurrency_enabled"`
+ // OAuthMaxConnsFactor: OAuth 账号连接池系数(effective=ceil(concurrency*factor))
+ OAuthMaxConnsFactor float64 `mapstructure:"oauth_max_conns_factor"`
+ // APIKeyMaxConnsFactor: API Key 账号连接池系数(effective=ceil(concurrency*factor))
+ APIKeyMaxConnsFactor float64 `mapstructure:"apikey_max_conns_factor"`
+ DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"`
+ ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"`
+ WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"`
+ PoolTargetUtilization float64 `mapstructure:"pool_target_utilization"`
+ QueueLimitPerConn int `mapstructure:"queue_limit_per_conn"`
+ // EventFlushBatchSize: WS 流式写出批量 flush 阈值(事件条数)
+ EventFlushBatchSize int `mapstructure:"event_flush_batch_size"`
+ // EventFlushIntervalMS: WS 流式写出最大等待时间(毫秒);0 表示仅按 batch 触发
+ EventFlushIntervalMS int `mapstructure:"event_flush_interval_ms"`
+ // PrewarmCooldownMS: 连接池预热触发冷却时间(毫秒)
+ PrewarmCooldownMS int `mapstructure:"prewarm_cooldown_ms"`
+ // FallbackCooldownSeconds: WS 回退冷却窗口,避免 WS/HTTP 抖动;0 表示关闭冷却
+ FallbackCooldownSeconds int `mapstructure:"fallback_cooldown_seconds"`
+ // RetryBackoffInitialMS: WS 重试初始退避(毫秒);<=0 表示关闭退避
+ RetryBackoffInitialMS int `mapstructure:"retry_backoff_initial_ms"`
+ // RetryBackoffMaxMS: WS 重试最大退避(毫秒)
+ RetryBackoffMaxMS int `mapstructure:"retry_backoff_max_ms"`
+ // RetryJitterRatio: WS 重试退避抖动比例(0-1)
+ RetryJitterRatio float64 `mapstructure:"retry_jitter_ratio"`
+ // RetryTotalBudgetMS: WS 单次请求重试总预算(毫秒);0 表示关闭预算限制
+ RetryTotalBudgetMS int `mapstructure:"retry_total_budget_ms"`
+ // PayloadLogSampleRate: payload_schema 日志采样率(0-1)
+ PayloadLogSampleRate float64 `mapstructure:"payload_log_sample_rate"`
+
+ // 账号调度与粘连参数
+ LBTopK int `mapstructure:"lb_top_k"`
+ // StickySessionTTLSeconds: session_hash -> account_id 粘连 TTL
+ StickySessionTTLSeconds int `mapstructure:"sticky_session_ttl_seconds"`
+ // SessionHashReadOldFallback: 会话哈希迁移期是否允许“新 key 未命中时回退读旧 SHA-256 key”
+ SessionHashReadOldFallback bool `mapstructure:"session_hash_read_old_fallback"`
+ // SessionHashDualWriteOld: 会话哈希迁移期是否双写旧 SHA-256 key(短 TTL)
+ SessionHashDualWriteOld bool `mapstructure:"session_hash_dual_write_old"`
+ // MetadataBridgeEnabled: RequestMetadata 迁移期是否保留旧 ctxkey.* 兼容桥接
+ MetadataBridgeEnabled bool `mapstructure:"metadata_bridge_enabled"`
+ // StickyResponseIDTTLSeconds: response_id -> account_id 粘连 TTL
+ StickyResponseIDTTLSeconds int `mapstructure:"sticky_response_id_ttl_seconds"`
+ // StickyPreviousResponseTTLSeconds: 兼容旧键(当新键未设置时回退)
+ StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"`
+
+ SchedulerScoreWeights GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"`
+}
+
+// GatewayOpenAIWSSchedulerScoreWeights 账号调度打分权重。
+type GatewayOpenAIWSSchedulerScoreWeights struct {
+ Priority float64 `mapstructure:"priority"`
+ Load float64 `mapstructure:"load"`
+ Queue float64 `mapstructure:"queue"`
+ ErrorRate float64 `mapstructure:"error_rate"`
+ TTFT float64 `mapstructure:"ttft"`
+}
+
// GatewayUsageRecordConfig 使用量记录异步队列配置
type GatewayUsageRecordConfig struct {
// WorkerCount: worker 初始数量(自动扩缩容开启时作为初始并发上限)
@@ -886,6 +983,12 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel))
cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath)
+ // 兼容旧键 gateway.openai_ws.sticky_previous_response_ttl_seconds。
+ // 新键未配置(<=0)时回退旧键;新键优先。
+ if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
+ cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
+ }
+
// Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
if cfg.Totp.EncryptionKey == "" {
@@ -945,7 +1048,7 @@ func setDefaults() {
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
viper.SetDefault("server.trusted_proxies", []string{})
- viper.SetDefault("server.max_request_body_size", int64(100*1024*1024))
+ viper.SetDefault("server.max_request_body_size", int64(256*1024*1024))
// H2C 默认配置
viper.SetDefault("server.h2c.enabled", false)
viper.SetDefault("server.h2c.max_concurrent_streams", uint32(50)) // 50 个并发流
@@ -1088,9 +1191,9 @@ func setDefaults() {
// RateLimit
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
- // Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据的配置
- viper.SetDefault("pricing.remote_url", "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.json")
- viper.SetDefault("pricing.hash_url", "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.sha256")
+ // Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit,避免分支漂移)
+ viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json")
+ viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256")
viper.SetDefault("pricing.data_dir", "./data")
viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json")
viper.SetDefault("pricing.update_interval_hours", 24)
@@ -1157,9 +1260,55 @@ func setDefaults() {
viper.SetDefault("gateway.max_account_switches_gemini", 3)
viper.SetDefault("gateway.force_codex_cli", false)
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
+ // OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
+ viper.SetDefault("gateway.openai_ws.enabled", true)
+ viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
+ viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared")
+ viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
+ viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
+ viper.SetDefault("gateway.openai_ws.force_http", false)
+ viper.SetDefault("gateway.openai_ws.allow_store_recovery", false)
+ viper.SetDefault("gateway.openai_ws.ingress_previous_response_recovery_enabled", true)
+ viper.SetDefault("gateway.openai_ws.store_disabled_conn_mode", "strict")
+ viper.SetDefault("gateway.openai_ws.store_disabled_force_new_conn", true)
+ viper.SetDefault("gateway.openai_ws.prewarm_generate_enabled", false)
+ viper.SetDefault("gateway.openai_ws.responses_websockets", false)
+ viper.SetDefault("gateway.openai_ws.responses_websockets_v2", true)
+ viper.SetDefault("gateway.openai_ws.max_conns_per_account", 128)
+ viper.SetDefault("gateway.openai_ws.min_idle_per_account", 4)
+ viper.SetDefault("gateway.openai_ws.max_idle_per_account", 12)
+ viper.SetDefault("gateway.openai_ws.dynamic_max_conns_by_account_concurrency_enabled", true)
+ viper.SetDefault("gateway.openai_ws.oauth_max_conns_factor", 1.0)
+ viper.SetDefault("gateway.openai_ws.apikey_max_conns_factor", 1.0)
+ viper.SetDefault("gateway.openai_ws.dial_timeout_seconds", 10)
+ viper.SetDefault("gateway.openai_ws.read_timeout_seconds", 900)
+ viper.SetDefault("gateway.openai_ws.write_timeout_seconds", 120)
+ viper.SetDefault("gateway.openai_ws.pool_target_utilization", 0.7)
+ viper.SetDefault("gateway.openai_ws.queue_limit_per_conn", 64)
+ viper.SetDefault("gateway.openai_ws.event_flush_batch_size", 1)
+ viper.SetDefault("gateway.openai_ws.event_flush_interval_ms", 10)
+ viper.SetDefault("gateway.openai_ws.prewarm_cooldown_ms", 300)
+ viper.SetDefault("gateway.openai_ws.fallback_cooldown_seconds", 30)
+ viper.SetDefault("gateway.openai_ws.retry_backoff_initial_ms", 120)
+ viper.SetDefault("gateway.openai_ws.retry_backoff_max_ms", 2000)
+ viper.SetDefault("gateway.openai_ws.retry_jitter_ratio", 0.2)
+ viper.SetDefault("gateway.openai_ws.retry_total_budget_ms", 5000)
+ viper.SetDefault("gateway.openai_ws.payload_log_sample_rate", 0.2)
+ viper.SetDefault("gateway.openai_ws.lb_top_k", 7)
+ viper.SetDefault("gateway.openai_ws.sticky_session_ttl_seconds", 3600)
+ viper.SetDefault("gateway.openai_ws.session_hash_read_old_fallback", true)
+ viper.SetDefault("gateway.openai_ws.session_hash_dual_write_old", true)
+ viper.SetDefault("gateway.openai_ws.metadata_bridge_enabled", true)
+ viper.SetDefault("gateway.openai_ws.sticky_response_id_ttl_seconds", 3600)
+ viper.SetDefault("gateway.openai_ws.sticky_previous_response_ttl_seconds", 3600)
+ viper.SetDefault("gateway.openai_ws.scheduler_score_weights.priority", 1.0)
+ viper.SetDefault("gateway.openai_ws.scheduler_score_weights.load", 1.0)
+ viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7)
+ viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8)
+ viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10)
- viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
+ viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
viper.SetDefault("gateway.gemini_debug_response_headers", false)
@@ -1747,6 +1896,118 @@ func (c *Config) Validate() error {
(c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
}
+ // 兼容旧键 sticky_previous_response_ttl_seconds
+ if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
+ c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
+ }
+ if c.Gateway.OpenAIWS.MaxConnsPerAccount <= 0 {
+ return fmt.Errorf("gateway.openai_ws.max_conns_per_account must be positive")
+ }
+ if c.Gateway.OpenAIWS.MinIdlePerAccount < 0 {
+ return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be non-negative")
+ }
+ if c.Gateway.OpenAIWS.MaxIdlePerAccount < 0 {
+ return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be non-negative")
+ }
+ if c.Gateway.OpenAIWS.MinIdlePerAccount > c.Gateway.OpenAIWS.MaxIdlePerAccount {
+ return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be <= max_idle_per_account")
+ }
+ if c.Gateway.OpenAIWS.MaxIdlePerAccount > c.Gateway.OpenAIWS.MaxConnsPerAccount {
+ return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be <= max_conns_per_account")
+ }
+ if c.Gateway.OpenAIWS.OAuthMaxConnsFactor <= 0 {
+ return fmt.Errorf("gateway.openai_ws.oauth_max_conns_factor must be positive")
+ }
+ if c.Gateway.OpenAIWS.APIKeyMaxConnsFactor <= 0 {
+ return fmt.Errorf("gateway.openai_ws.apikey_max_conns_factor must be positive")
+ }
+ if c.Gateway.OpenAIWS.DialTimeoutSeconds <= 0 {
+ return fmt.Errorf("gateway.openai_ws.dial_timeout_seconds must be positive")
+ }
+ if c.Gateway.OpenAIWS.ReadTimeoutSeconds <= 0 {
+ return fmt.Errorf("gateway.openai_ws.read_timeout_seconds must be positive")
+ }
+ if c.Gateway.OpenAIWS.WriteTimeoutSeconds <= 0 {
+ return fmt.Errorf("gateway.openai_ws.write_timeout_seconds must be positive")
+ }
+ if c.Gateway.OpenAIWS.PoolTargetUtilization <= 0 || c.Gateway.OpenAIWS.PoolTargetUtilization > 1 {
+ return fmt.Errorf("gateway.openai_ws.pool_target_utilization must be within (0,1]")
+ }
+ if c.Gateway.OpenAIWS.QueueLimitPerConn <= 0 {
+ return fmt.Errorf("gateway.openai_ws.queue_limit_per_conn must be positive")
+ }
+ if c.Gateway.OpenAIWS.EventFlushBatchSize <= 0 {
+ return fmt.Errorf("gateway.openai_ws.event_flush_batch_size must be positive")
+ }
+ if c.Gateway.OpenAIWS.EventFlushIntervalMS < 0 {
+ return fmt.Errorf("gateway.openai_ws.event_flush_interval_ms must be non-negative")
+ }
+ if c.Gateway.OpenAIWS.PrewarmCooldownMS < 0 {
+ return fmt.Errorf("gateway.openai_ws.prewarm_cooldown_ms must be non-negative")
+ }
+ if c.Gateway.OpenAIWS.FallbackCooldownSeconds < 0 {
+ return fmt.Errorf("gateway.openai_ws.fallback_cooldown_seconds must be non-negative")
+ }
+ if c.Gateway.OpenAIWS.RetryBackoffInitialMS < 0 {
+ return fmt.Errorf("gateway.openai_ws.retry_backoff_initial_ms must be non-negative")
+ }
+ if c.Gateway.OpenAIWS.RetryBackoffMaxMS < 0 {
+ return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be non-negative")
+ }
+ if c.Gateway.OpenAIWS.RetryBackoffInitialMS > 0 && c.Gateway.OpenAIWS.RetryBackoffMaxMS > 0 &&
+ c.Gateway.OpenAIWS.RetryBackoffMaxMS < c.Gateway.OpenAIWS.RetryBackoffInitialMS {
+ return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be >= retry_backoff_initial_ms")
+ }
+ if c.Gateway.OpenAIWS.RetryJitterRatio < 0 || c.Gateway.OpenAIWS.RetryJitterRatio > 1 {
+ return fmt.Errorf("gateway.openai_ws.retry_jitter_ratio must be within [0,1]")
+ }
+ if c.Gateway.OpenAIWS.RetryTotalBudgetMS < 0 {
+ return fmt.Errorf("gateway.openai_ws.retry_total_budget_ms must be non-negative")
+ }
+ if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" {
+ switch mode {
+ case "off", "shared", "dedicated":
+ default:
+ return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated")
+ }
+ }
+ if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" {
+ switch mode {
+ case "strict", "adaptive", "off":
+ default:
+ return fmt.Errorf("gateway.openai_ws.store_disabled_conn_mode must be one of strict|adaptive|off")
+ }
+ }
+ if c.Gateway.OpenAIWS.PayloadLogSampleRate < 0 || c.Gateway.OpenAIWS.PayloadLogSampleRate > 1 {
+ return fmt.Errorf("gateway.openai_ws.payload_log_sample_rate must be within [0,1]")
+ }
+ if c.Gateway.OpenAIWS.LBTopK <= 0 {
+ return fmt.Errorf("gateway.openai_ws.lb_top_k must be positive")
+ }
+ if c.Gateway.OpenAIWS.StickySessionTTLSeconds <= 0 {
+ return fmt.Errorf("gateway.openai_ws.sticky_session_ttl_seconds must be positive")
+ }
+ if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 {
+ return fmt.Errorf("gateway.openai_ws.sticky_response_id_ttl_seconds must be positive")
+ }
+ if c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds < 0 {
+ return fmt.Errorf("gateway.openai_ws.sticky_previous_response_ttl_seconds must be non-negative")
+ }
+ if c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority < 0 ||
+ c.Gateway.OpenAIWS.SchedulerScoreWeights.Load < 0 ||
+ c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue < 0 ||
+ c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate < 0 ||
+ c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT < 0 {
+ return fmt.Errorf("gateway.openai_ws.scheduler_score_weights.* must be non-negative")
+ }
+ weightSum := c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority +
+ c.Gateway.OpenAIWS.SchedulerScoreWeights.Load +
+ c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue +
+ c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate +
+ c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT
+ if weightSum <= 0 {
+ return fmt.Errorf("gateway.openai_ws.scheduler_score_weights must not all be zero")
+ }
if c.Gateway.MaxLineSize < 0 {
return fmt.Errorf("gateway.max_line_size must be non-negative")
}
diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go
index b0402a3b..e3b592e2 100644
--- a/backend/internal/config/config_test.go
+++ b/backend/internal/config/config_test.go
@@ -6,6 +6,7 @@ import (
"time"
"github.com/spf13/viper"
+ "github.com/stretchr/testify/require"
)
func resetViperWithJWTSecret(t *testing.T) {
@@ -75,6 +76,103 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
}
}
+func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
+ resetViperWithJWTSecret(t)
+
+ cfg, err := Load()
+ if err != nil {
+ t.Fatalf("Load() error: %v", err)
+ }
+
+ if !cfg.Gateway.OpenAIWS.Enabled {
+ t.Fatalf("Gateway.OpenAIWS.Enabled = false, want true")
+ }
+ if !cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 {
+ t.Fatalf("Gateway.OpenAIWS.ResponsesWebsocketsV2 = false, want true")
+ }
+ if cfg.Gateway.OpenAIWS.ResponsesWebsockets {
+ t.Fatalf("Gateway.OpenAIWS.ResponsesWebsockets = true, want false")
+ }
+ if !cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled {
+ t.Fatalf("Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = false, want true")
+ }
+ if cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor != 1.0 {
+ t.Fatalf("Gateway.OpenAIWS.OAuthMaxConnsFactor = %v, want 1.0", cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor)
+ }
+ if cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor != 1.0 {
+ t.Fatalf("Gateway.OpenAIWS.APIKeyMaxConnsFactor = %v, want 1.0", cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor)
+ }
+ if cfg.Gateway.OpenAIWS.StickySessionTTLSeconds != 3600 {
+ t.Fatalf("Gateway.OpenAIWS.StickySessionTTLSeconds = %d, want 3600", cfg.Gateway.OpenAIWS.StickySessionTTLSeconds)
+ }
+ if !cfg.Gateway.OpenAIWS.SessionHashReadOldFallback {
+ t.Fatalf("Gateway.OpenAIWS.SessionHashReadOldFallback = false, want true")
+ }
+ if !cfg.Gateway.OpenAIWS.SessionHashDualWriteOld {
+ t.Fatalf("Gateway.OpenAIWS.SessionHashDualWriteOld = false, want true")
+ }
+ if !cfg.Gateway.OpenAIWS.MetadataBridgeEnabled {
+ t.Fatalf("Gateway.OpenAIWS.MetadataBridgeEnabled = false, want true")
+ }
+ if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds != 3600 {
+ t.Fatalf("Gateway.OpenAIWS.StickyResponseIDTTLSeconds = %d, want 3600", cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds)
+ }
+ if cfg.Gateway.OpenAIWS.FallbackCooldownSeconds != 30 {
+ t.Fatalf("Gateway.OpenAIWS.FallbackCooldownSeconds = %d, want 30", cfg.Gateway.OpenAIWS.FallbackCooldownSeconds)
+ }
+ if cfg.Gateway.OpenAIWS.EventFlushBatchSize != 1 {
+ t.Fatalf("Gateway.OpenAIWS.EventFlushBatchSize = %d, want 1", cfg.Gateway.OpenAIWS.EventFlushBatchSize)
+ }
+ if cfg.Gateway.OpenAIWS.EventFlushIntervalMS != 10 {
+ t.Fatalf("Gateway.OpenAIWS.EventFlushIntervalMS = %d, want 10", cfg.Gateway.OpenAIWS.EventFlushIntervalMS)
+ }
+ if cfg.Gateway.OpenAIWS.PrewarmCooldownMS != 300 {
+ t.Fatalf("Gateway.OpenAIWS.PrewarmCooldownMS = %d, want 300", cfg.Gateway.OpenAIWS.PrewarmCooldownMS)
+ }
+ if cfg.Gateway.OpenAIWS.RetryBackoffInitialMS != 120 {
+ t.Fatalf("Gateway.OpenAIWS.RetryBackoffInitialMS = %d, want 120", cfg.Gateway.OpenAIWS.RetryBackoffInitialMS)
+ }
+ if cfg.Gateway.OpenAIWS.RetryBackoffMaxMS != 2000 {
+ t.Fatalf("Gateway.OpenAIWS.RetryBackoffMaxMS = %d, want 2000", cfg.Gateway.OpenAIWS.RetryBackoffMaxMS)
+ }
+ if cfg.Gateway.OpenAIWS.RetryJitterRatio != 0.2 {
+ t.Fatalf("Gateway.OpenAIWS.RetryJitterRatio = %v, want 0.2", cfg.Gateway.OpenAIWS.RetryJitterRatio)
+ }
+ if cfg.Gateway.OpenAIWS.RetryTotalBudgetMS != 5000 {
+ t.Fatalf("Gateway.OpenAIWS.RetryTotalBudgetMS = %d, want 5000", cfg.Gateway.OpenAIWS.RetryTotalBudgetMS)
+ }
+ if cfg.Gateway.OpenAIWS.PayloadLogSampleRate != 0.2 {
+ t.Fatalf("Gateway.OpenAIWS.PayloadLogSampleRate = %v, want 0.2", cfg.Gateway.OpenAIWS.PayloadLogSampleRate)
+ }
+ if !cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn {
+ t.Fatalf("Gateway.OpenAIWS.StoreDisabledForceNewConn = false, want true")
+ }
+ if cfg.Gateway.OpenAIWS.StoreDisabledConnMode != "strict" {
+ t.Fatalf("Gateway.OpenAIWS.StoreDisabledConnMode = %q, want %q", cfg.Gateway.OpenAIWS.StoreDisabledConnMode, "strict")
+ }
+ if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
+ t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false")
+ }
+ if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" {
+ t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared")
+ }
+}
+
+func TestLoadOpenAIWSStickyTTLCompatibility(t *testing.T) {
+ resetViperWithJWTSecret(t)
+ t.Setenv("GATEWAY_OPENAI_WS_STICKY_RESPONSE_ID_TTL_SECONDS", "0")
+ t.Setenv("GATEWAY_OPENAI_WS_STICKY_PREVIOUS_RESPONSE_TTL_SECONDS", "7200")
+
+ cfg, err := Load()
+ if err != nil {
+ t.Fatalf("Load() error: %v", err)
+ }
+
+ if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds != 7200 {
+ t.Fatalf("StickyResponseIDTTLSeconds = %d, want 7200", cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds)
+ }
+}
+
func TestLoadDefaultIdempotencyConfig(t *testing.T) {
resetViperWithJWTSecret(t)
@@ -993,6 +1091,16 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.Gateway.StreamKeepaliveInterval = 4 },
wantErr: "gateway.stream_keepalive_interval",
},
+ {
+ name: "gateway openai ws oauth max conns factor",
+ mutate: func(c *Config) { c.Gateway.OpenAIWS.OAuthMaxConnsFactor = 0 },
+ wantErr: "gateway.openai_ws.oauth_max_conns_factor",
+ },
+ {
+ name: "gateway openai ws apikey max conns factor",
+ mutate: func(c *Config) { c.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0 },
+ wantErr: "gateway.openai_ws.apikey_max_conns_factor",
+ },
{
name: "gateway stream data interval range",
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 },
@@ -1174,6 +1282,165 @@ func TestValidateConfigErrors(t *testing.T) {
}
}
+func TestValidateConfig_OpenAIWSRules(t *testing.T) {
+ buildValid := func(t *testing.T) *Config {
+ t.Helper()
+ resetViperWithJWTSecret(t)
+ cfg, err := Load()
+ require.NoError(t, err)
+ return cfg
+ }
+
+ t.Run("sticky response id ttl 兼容旧键回填", func(t *testing.T) {
+ cfg := buildValid(t)
+ cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 0
+ cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = 7200
+
+ require.NoError(t, cfg.Validate())
+ require.Equal(t, 7200, cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds)
+ })
+
+ cases := []struct {
+ name string
+ mutate func(*Config)
+ wantErr string
+ }{
+ {
+ name: "max_conns_per_account 必须为正数",
+ mutate: func(c *Config) { c.Gateway.OpenAIWS.MaxConnsPerAccount = 0 },
+ wantErr: "gateway.openai_ws.max_conns_per_account",
+ },
+ {
+ name: "min_idle_per_account 不能为负数",
+ mutate: func(c *Config) { c.Gateway.OpenAIWS.MinIdlePerAccount = -1 },
+ wantErr: "gateway.openai_ws.min_idle_per_account",
+ },
+ {
+ name: "max_idle_per_account 不能为负数",
+ mutate: func(c *Config) { c.Gateway.OpenAIWS.MaxIdlePerAccount = -1 },
+ wantErr: "gateway.openai_ws.max_idle_per_account",
+ },
+ {
+ name: "min_idle_per_account 不能大于 max_idle_per_account",
+ mutate: func(c *Config) {
+ c.Gateway.OpenAIWS.MinIdlePerAccount = 3
+ c.Gateway.OpenAIWS.MaxIdlePerAccount = 2
+ },
+ wantErr: "gateway.openai_ws.min_idle_per_account must be <= max_idle_per_account",
+ },
+ {
+ name: "max_idle_per_account 不能大于 max_conns_per_account",
+ mutate: func(c *Config) {
+ c.Gateway.OpenAIWS.MaxConnsPerAccount = 2
+ c.Gateway.OpenAIWS.MinIdlePerAccount = 1
+ c.Gateway.OpenAIWS.MaxIdlePerAccount = 3
+ },
+ wantErr: "gateway.openai_ws.max_idle_per_account must be <= max_conns_per_account",
+ },
+ {
+ name: "dial_timeout_seconds 必须为正数",
+ mutate: func(c *Config) { c.Gateway.OpenAIWS.DialTimeoutSeconds = 0 },
+ wantErr: "gateway.openai_ws.dial_timeout_seconds",
+ },
+ {
+ name: "read_timeout_seconds 必须为正数",
+ mutate: func(c *Config) { c.Gateway.OpenAIWS.ReadTimeoutSeconds = 0 },
+ wantErr: "gateway.openai_ws.read_timeout_seconds",
+ },
+ {
+ name: "write_timeout_seconds 必须为正数",
+ mutate: func(c *Config) { c.Gateway.OpenAIWS.WriteTimeoutSeconds = 0 },
+ wantErr: "gateway.openai_ws.write_timeout_seconds",
+ },
+ {
+ name: "pool_target_utilization 必须在 (0,1]",
+ mutate: func(c *Config) { c.Gateway.OpenAIWS.PoolTargetUtilization = 0 },
+ wantErr: "gateway.openai_ws.pool_target_utilization",
+ },
+ {
+ name: "queue_limit_per_conn 必须为正数",
+ mutate: func(c *Config) { c.Gateway.OpenAIWS.QueueLimitPerConn = 0 },
+ wantErr: "gateway.openai_ws.queue_limit_per_conn",
+ },
+ {
+ name: "fallback_cooldown_seconds 不能为负数",
+ mutate: func(c *Config) { c.Gateway.OpenAIWS.FallbackCooldownSeconds = -1 },
+ wantErr: "gateway.openai_ws.fallback_cooldown_seconds",
+ },
+ {
+ name: "store_disabled_conn_mode 必须为 strict|adaptive|off",
+ mutate: func(c *Config) { c.Gateway.OpenAIWS.StoreDisabledConnMode = "invalid" },
+ wantErr: "gateway.openai_ws.store_disabled_conn_mode",
+ },
+ {
+ name: "ingress_mode_default 必须为 off|shared|dedicated",
+ mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" },
+ wantErr: "gateway.openai_ws.ingress_mode_default",
+ },
+ {
+ name: "payload_log_sample_rate 必须在 [0,1] 范围内",
+ mutate: func(c *Config) { c.Gateway.OpenAIWS.PayloadLogSampleRate = 1.2 },
+ wantErr: "gateway.openai_ws.payload_log_sample_rate",
+ },
+ {
+ name: "retry_total_budget_ms 不能为负数",
+ mutate: func(c *Config) { c.Gateway.OpenAIWS.RetryTotalBudgetMS = -1 },
+ wantErr: "gateway.openai_ws.retry_total_budget_ms",
+ },
+ {
+ name: "lb_top_k 必须为正数",
+ mutate: func(c *Config) { c.Gateway.OpenAIWS.LBTopK = 0 },
+ wantErr: "gateway.openai_ws.lb_top_k",
+ },
+ {
+ name: "sticky_session_ttl_seconds 必须为正数",
+ mutate: func(c *Config) { c.Gateway.OpenAIWS.StickySessionTTLSeconds = 0 },
+ wantErr: "gateway.openai_ws.sticky_session_ttl_seconds",
+ },
+ {
+ name: "sticky_response_id_ttl_seconds 必须为正数",
+ mutate: func(c *Config) {
+ c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 0
+ c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = 0
+ },
+ wantErr: "gateway.openai_ws.sticky_response_id_ttl_seconds",
+ },
+ {
+ name: "sticky_previous_response_ttl_seconds 不能为负数",
+ mutate: func(c *Config) { c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = -1 },
+ wantErr: "gateway.openai_ws.sticky_previous_response_ttl_seconds",
+ },
+ {
+ name: "scheduler_score_weights 不能为负数",
+ mutate: func(c *Config) { c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = -0.1 },
+ wantErr: "gateway.openai_ws.scheduler_score_weights.* must be non-negative",
+ },
+ {
+ name: "scheduler_score_weights 不能全为 0",
+ mutate: func(c *Config) {
+ c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0
+ c.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0
+ c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0
+ c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0
+ c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0
+ },
+ wantErr: "gateway.openai_ws.scheduler_score_weights must not all be zero",
+ },
+ }
+
+ for _, tc := range cases {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ cfg := buildValid(t)
+ tc.mutate(cfg)
+
+ err := cfg.Validate()
+ require.Error(t, err)
+ require.Contains(t, err.Error(), tc.wantErr)
+ })
+ }
+}
+
func TestValidateConfig_AutoScaleDisabledIgnoreAutoScaleFields(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go
index c41aa65f..d7bb50fc 100644
--- a/backend/internal/domain/constants.go
+++ b/backend/internal/domain/constants.go
@@ -89,19 +89,24 @@ var DefaultAntigravityModelMapping = map[string]string{
"gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
"gemini-2.5-pro": "gemini-2.5-pro",
// Gemini 3 白名单
- "gemini-3-flash": "gemini-3-flash",
- "gemini-3-pro-high": "gemini-3-pro-high",
- "gemini-3-pro-low": "gemini-3-pro-low",
- "gemini-3-pro-image": "gemini-3-pro-image",
+ "gemini-3-flash": "gemini-3-flash",
+ "gemini-3-pro-high": "gemini-3-pro-high",
+ "gemini-3-pro-low": "gemini-3-pro-low",
// Gemini 3 preview 映射
- "gemini-3-flash-preview": "gemini-3-flash",
- "gemini-3-pro-preview": "gemini-3-pro-high",
- "gemini-3-pro-image-preview": "gemini-3-pro-image",
+ "gemini-3-flash-preview": "gemini-3-flash",
+ "gemini-3-pro-preview": "gemini-3-pro-high",
// Gemini 3.1 白名单
"gemini-3.1-pro-high": "gemini-3.1-pro-high",
"gemini-3.1-pro-low": "gemini-3.1-pro-low",
// Gemini 3.1 preview 映射
"gemini-3.1-pro-preview": "gemini-3.1-pro-high",
+ // Gemini 3.1 image 白名单
+ "gemini-3.1-flash-image": "gemini-3.1-flash-image",
+ // Gemini 3.1 image preview 映射
+ "gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
+ // Gemini 3 image 兼容映射(向 3.1 image 迁移)
+ "gemini-3-pro-image": "gemini-3.1-flash-image",
+ "gemini-3-pro-image-preview": "gemini-3.1-flash-image",
// 其他官方模型
"gpt-oss-120b-medium": "gpt-oss-120b-medium",
"tab_flash_lite_preview": "tab_flash_lite_preview",
diff --git a/backend/internal/domain/constants_test.go b/backend/internal/domain/constants_test.go
new file mode 100644
index 00000000..29605ac6
--- /dev/null
+++ b/backend/internal/domain/constants_test.go
@@ -0,0 +1,24 @@
+package domain
+
+import "testing"
+
+func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) {
+ t.Parallel()
+
+ cases := map[string]string{
+ "gemini-3.1-flash-image": "gemini-3.1-flash-image",
+ "gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
+ "gemini-3-pro-image": "gemini-3.1-flash-image",
+ "gemini-3-pro-image-preview": "gemini-3.1-flash-image",
+ }
+
+ for from, want := range cases {
+ got, ok := DefaultAntigravityModelMapping[from]
+ if !ok {
+ t.Fatalf("expected mapping for %q to exist", from)
+ }
+ if got != want {
+ t.Fatalf("unexpected mapping for %q: got %q want %q", from, got, want)
+ }
+ }
+}
diff --git a/backend/internal/handler/admin/account_data_handler_test.go b/backend/internal/handler/admin/account_data_handler_test.go
index c8b04c2a..285033a1 100644
--- a/backend/internal/handler/admin/account_data_handler_test.go
+++ b/backend/internal/handler/admin/account_data_handler_test.go
@@ -64,6 +64,7 @@ func setupAccountDataRouter() (*gin.Engine, *stubAdminService) {
nil,
nil,
nil,
+ nil,
)
router.GET("/api/v1/admin/accounts/data", h.ExportData)
diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go
index df82476c..98ead284 100644
--- a/backend/internal/handler/admin/account_handler.go
+++ b/backend/internal/handler/admin/account_handler.go
@@ -16,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
@@ -52,6 +53,7 @@ type AccountHandler struct {
concurrencyService *service.ConcurrencyService
crsSyncService *service.CRSSyncService
sessionLimitCache service.SessionLimitCache
+ rpmCache service.RPMCache
tokenCacheInvalidator service.TokenCacheInvalidator
}
@@ -68,6 +70,7 @@ func NewAccountHandler(
concurrencyService *service.ConcurrencyService,
crsSyncService *service.CRSSyncService,
sessionLimitCache service.SessionLimitCache,
+ rpmCache service.RPMCache,
tokenCacheInvalidator service.TokenCacheInvalidator,
) *AccountHandler {
return &AccountHandler{
@@ -82,6 +85,7 @@ func NewAccountHandler(
concurrencyService: concurrencyService,
crsSyncService: crsSyncService,
sessionLimitCache: sessionLimitCache,
+ rpmCache: rpmCache,
tokenCacheInvalidator: tokenCacheInvalidator,
}
}
@@ -153,6 +157,7 @@ type AccountWithConcurrency struct {
// 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回
CurrentWindowCost *float64 `json:"current_window_cost,omitempty"` // 当前窗口费用
ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数
+ CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数
}
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
@@ -188,6 +193,12 @@ func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, ac
}
}
}
+
+ if h.rpmCache != nil && account.GetBaseRPM() > 0 {
+ if rpm, err := h.rpmCache.GetRPM(ctx, account.ID); err == nil {
+ item.CurrentRPM = &rpm
+ }
+ }
}
return item
@@ -230,9 +241,10 @@ func (h *AccountHandler) List(c *gin.Context) {
concurrencyCounts = make(map[int64]int)
}
- // 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
+ // 识别需要查询窗口费用、会话数和 RPM 的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs := make([]int64, 0)
sessionLimitAccountIDs := make([]int64, 0)
+ rpmAccountIDs := make([]int64, 0)
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
for i := range accounts {
acc := &accounts[i]
@@ -244,12 +256,24 @@ func (h *AccountHandler) List(c *gin.Context) {
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
}
+ if acc.GetBaseRPM() > 0 {
+ rpmAccountIDs = append(rpmAccountIDs, acc.ID)
+ }
}
}
- // 并行获取窗口费用和活跃会话数
+ // 并行获取窗口费用、活跃会话数和 RPM 计数
var windowCosts map[int64]float64
var activeSessions map[int64]int
+ var rpmCounts map[int64]int
+
+ // 获取 RPM 计数(批量查询)
+ if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
+ rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
+ if rpmCounts == nil {
+ rpmCounts = make(map[int64]int)
+ }
+ }
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
@@ -310,6 +334,13 @@ func (h *AccountHandler) List(c *gin.Context) {
}
}
+ // 添加 RPM 计数(仅当启用时)
+ if rpmCounts != nil {
+ if rpm, ok := rpmCounts[acc.ID]; ok {
+ item.CurrentRPM = &rpm
+ }
+ }
+
result[i] = item
}
@@ -452,6 +483,8 @@ func (h *AccountHandler) Create(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
+ // base_rpm 输入校验:负值归零,超过 10000 截断
+ sanitizeExtraBaseRPM(req.Extra)
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -521,6 +554,8 @@ func (h *AccountHandler) Update(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
+ // base_rpm 输入校验:负值归零,超过 10000 截断
+ sanitizeExtraBaseRPM(req.Extra)
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -903,6 +938,9 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
continue
}
+ // base_rpm 输入校验:负值归零,超过 10000 截断
+ sanitizeExtraBaseRPM(item.Extra)
+
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
@@ -1047,6 +1085,8 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
+ // base_rpm 输入校验:负值归零,超过 10000 截断
+ sanitizeExtraBaseRPM(req.Extra)
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -1082,6 +1122,14 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
+ var mixedErr *service.MixedChannelError
+ if errors.As(err, &mixedErr) {
+ c.JSON(409, gin.H{
+ "error": "mixed_channel_warning",
+ "message": mixedErr.Error(),
+ })
+ return
+ }
response.ErrorFrom(c, err)
return
}
@@ -1336,6 +1384,34 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) {
response.Success(c, stats)
}
+// BatchTodayStatsRequest 批量今日统计请求体。
+type BatchTodayStatsRequest struct {
+ AccountIDs []int64 `json:"account_ids" binding:"required"`
+}
+
+// GetBatchTodayStats 批量获取多个账号的今日统计。
+// POST /api/v1/admin/accounts/today-stats/batch
+func (h *AccountHandler) GetBatchTodayStats(c *gin.Context) {
+ var req BatchTodayStatsRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if len(req.AccountIDs) == 0 {
+ response.Success(c, gin.H{"stats": map[string]any{}})
+ return
+ }
+
+ stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), req.AccountIDs)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"stats": stats})
+}
+
// SetSchedulableRequest represents the request body for setting schedulable status
type SetSchedulableRequest struct {
Schedulable bool `json:"schedulable"`
@@ -1459,32 +1535,8 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
// Handle Antigravity accounts: return Claude + Gemini models
if account.Platform == service.PlatformAntigravity {
- // Antigravity 支持 Claude 和部分 Gemini 模型
- type UnifiedModel struct {
- ID string `json:"id"`
- Type string `json:"type"`
- DisplayName string `json:"display_name"`
- }
-
- var models []UnifiedModel
-
- // 添加 Claude 模型
- for _, m := range claude.DefaultModels {
- models = append(models, UnifiedModel{
- ID: m.ID,
- Type: m.Type,
- DisplayName: m.DisplayName,
- })
- }
-
- // 添加 Gemini 3 系列模型用于测试
- geminiTestModels := []UnifiedModel{
- {ID: "gemini-3-flash", Type: "model", DisplayName: "Gemini 3 Flash"},
- {ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview"},
- }
- models = append(models, geminiTestModels...)
-
- response.Success(c, models)
+ // 直接复用 antigravity.DefaultModels(),与 /v1/models 端点保持同步
+ response.Success(c, antigravity.DefaultModels())
return
}
@@ -1701,3 +1753,22 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
func (h *AccountHandler) GetAntigravityDefaultModelMapping(c *gin.Context) {
response.Success(c, domain.DefaultAntigravityModelMapping)
}
+
+// sanitizeExtraBaseRPM 对 extra map 中的 base_rpm 值进行范围校验和归一化。
+// 负值归零,超过 10000 截断为 10000。extra 为 nil 或不含 base_rpm 时无操作。
+func sanitizeExtraBaseRPM(extra map[string]any) {
+ if extra == nil {
+ return
+ }
+ raw, ok := extra["base_rpm"]
+ if !ok {
+ return
+ }
+ v := service.ParseExtraInt(raw)
+ if v < 0 {
+ v = 0
+ } else if v > 10000 {
+ v = 10000
+ }
+ extra["base_rpm"] = v
+}
diff --git a/backend/internal/handler/admin/account_handler_mixed_channel_test.go b/backend/internal/handler/admin/account_handler_mixed_channel_test.go
index ad004844..24ec5bcf 100644
--- a/backend/internal/handler/admin/account_handler_mixed_channel_test.go
+++ b/backend/internal/handler/admin/account_handler_mixed_channel_test.go
@@ -15,10 +15,11 @@ import (
func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
- accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
+ accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/check-mixed-channel", accountHandler.CheckMixedChannel)
router.POST("/api/v1/admin/accounts", accountHandler.Create)
router.PUT("/api/v1/admin/accounts/:id", accountHandler.Update)
+ router.POST("/api/v1/admin/accounts/bulk-update", accountHandler.BulkUpdate)
return router
}
@@ -145,3 +146,53 @@ func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T
require.False(t, hasDetails)
require.False(t, hasRequireConfirmation)
}
+
+func TestAccountHandlerBulkUpdateMixedChannelConflict(t *testing.T) {
+ adminSvc := newStubAdminService()
+ adminSvc.bulkUpdateAccountErr = &service.MixedChannelError{
+ GroupID: 27,
+ GroupName: "claude-max",
+ CurrentPlatform: "Antigravity",
+ OtherPlatform: "Anthropic",
+ }
+ router := setupAccountMixedChannelRouter(adminSvc)
+
+ body, _ := json.Marshal(map[string]any{
+ "account_ids": []int64{1, 2, 3},
+ "group_ids": []int64{27},
+ })
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusConflict, rec.Code)
+ var resp map[string]any
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ require.Equal(t, "mixed_channel_warning", resp["error"])
+ require.Contains(t, resp["message"], "claude-max")
+}
+
+func TestAccountHandlerBulkUpdateMixedChannelConfirmSkips(t *testing.T) {
+ adminSvc := newStubAdminService()
+ router := setupAccountMixedChannelRouter(adminSvc)
+
+ body, _ := json.Marshal(map[string]any{
+ "account_ids": []int64{1, 2},
+ "group_ids": []int64{27},
+ "confirm_mixed_channel_risk": true,
+ })
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ var resp map[string]any
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ require.Equal(t, float64(0), resp["code"])
+ data, ok := resp["data"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, float64(2), data["success"])
+ require.Equal(t, float64(0), data["failed"])
+}
diff --git a/backend/internal/handler/admin/account_handler_passthrough_test.go b/backend/internal/handler/admin/account_handler_passthrough_test.go
index d09cccd6..d86501c0 100644
--- a/backend/internal/handler/admin/account_handler_passthrough_test.go
+++ b/backend/internal/handler/admin/account_handler_passthrough_test.go
@@ -28,6 +28,7 @@ func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testi
nil,
nil,
nil,
+ nil,
)
router := gin.New()
diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go
index aeb4097f..4de10d3e 100644
--- a/backend/internal/handler/admin/admin_basic_handlers_test.go
+++ b/backend/internal/handler/admin/admin_basic_handlers_test.go
@@ -19,7 +19,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
userHandler := NewUserHandler(adminSvc, nil)
groupHandler := NewGroupHandler(adminSvc)
proxyHandler := NewProxyHandler(adminSvc)
- redeemHandler := NewRedeemHandler(adminSvc)
+ redeemHandler := NewRedeemHandler(adminSvc, nil)
router.GET("/api/v1/admin/users", userHandler.List)
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go
index 848122e4..f3b99ddb 100644
--- a/backend/internal/handler/admin/admin_service_stub_test.go
+++ b/backend/internal/handler/admin/admin_service_stub_test.go
@@ -10,22 +10,23 @@ import (
)
type stubAdminService struct {
- users []service.User
- apiKeys []service.APIKey
- groups []service.Group
- accounts []service.Account
- proxies []service.Proxy
- proxyCounts []service.ProxyWithAccountCount
- redeems []service.RedeemCode
- createdAccounts []*service.CreateAccountInput
- createdProxies []*service.CreateProxyInput
- updatedProxyIDs []int64
- updatedProxies []*service.UpdateProxyInput
- testedProxyIDs []int64
- createAccountErr error
- updateAccountErr error
- checkMixedErr error
- lastMixedCheck struct {
+ users []service.User
+ apiKeys []service.APIKey
+ groups []service.Group
+ accounts []service.Account
+ proxies []service.Proxy
+ proxyCounts []service.ProxyWithAccountCount
+ redeems []service.RedeemCode
+ createdAccounts []*service.CreateAccountInput
+ createdProxies []*service.CreateProxyInput
+ updatedProxyIDs []int64
+ updatedProxies []*service.UpdateProxyInput
+ testedProxyIDs []int64
+ createAccountErr error
+ updateAccountErr error
+ bulkUpdateAccountErr error
+ checkMixedErr error
+ lastMixedCheck struct {
accountID int64
platform string
groupIDs []int64
@@ -235,7 +236,10 @@ func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64,
}
func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) {
- return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
+ if s.bulkUpdateAccountErr != nil {
+ return nil, s.bulkUpdateAccountErr
+ }
+ return &service.BulkUpdateAccountsResult{Success: len(input.AccountIDs), Failed: 0, SuccessIDs: input.AccountIDs}, nil
}
func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
@@ -403,5 +407,23 @@ func (s *stubAdminService) UpdateGroupSortOrders(ctx context.Context, updates []
return nil
}
+func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) {
+ for i := range s.apiKeys {
+ if s.apiKeys[i].ID == keyID {
+ k := s.apiKeys[i]
+ if groupID != nil {
+ if *groupID == 0 {
+ k.GroupID = nil
+ } else {
+ gid := *groupID
+ k.GroupID = &gid
+ }
+ }
+ return &service.AdminUpdateAPIKeyGroupIDResult{APIKey: &k}, nil
+ }
+ }
+ return nil, service.ErrAPIKeyNotFound
+}
+
// Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil)
diff --git a/backend/internal/handler/admin/apikey_handler.go b/backend/internal/handler/admin/apikey_handler.go
new file mode 100644
index 00000000..8dd245a4
--- /dev/null
+++ b/backend/internal/handler/admin/apikey_handler.go
@@ -0,0 +1,63 @@
+package admin
+
+import (
+ "strconv"
+
+ "github.com/Wei-Shaw/sub2api/internal/handler/dto"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// AdminAPIKeyHandler handles admin API key management
+type AdminAPIKeyHandler struct {
+ adminService service.AdminService
+}
+
+// NewAdminAPIKeyHandler creates a new admin API key handler
+func NewAdminAPIKeyHandler(adminService service.AdminService) *AdminAPIKeyHandler {
+ return &AdminAPIKeyHandler{
+ adminService: adminService,
+ }
+}
+
+// AdminUpdateAPIKeyGroupRequest represents the request to update an API key's group
+type AdminUpdateAPIKeyGroupRequest struct {
+ GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
+}
+
+// UpdateGroup handles updating an API key's group binding
+// PUT /api/v1/admin/api-keys/:id
+func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
+ keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.BadRequest(c, "Invalid API key ID")
+ return
+ }
+
+ var req AdminUpdateAPIKeyGroupRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ result, err := h.adminService.AdminUpdateAPIKeyGroupID(c.Request.Context(), keyID, req.GroupID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ resp := struct {
+ APIKey *dto.APIKey `json:"api_key"`
+ AutoGrantedGroupAccess bool `json:"auto_granted_group_access"`
+ GrantedGroupID *int64 `json:"granted_group_id,omitempty"`
+ GrantedGroupName string `json:"granted_group_name,omitempty"`
+ }{
+ APIKey: dto.APIKeyFromService(result.APIKey),
+ AutoGrantedGroupAccess: result.AutoGrantedGroupAccess,
+ GrantedGroupID: result.GrantedGroupID,
+ GrantedGroupName: result.GrantedGroupName,
+ }
+ response.Success(c, resp)
+}
diff --git a/backend/internal/handler/admin/apikey_handler_test.go b/backend/internal/handler/admin/apikey_handler_test.go
new file mode 100644
index 00000000..bf128b18
--- /dev/null
+++ b/backend/internal/handler/admin/apikey_handler_test.go
@@ -0,0 +1,202 @@
+package admin
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func setupAPIKeyHandler(adminSvc service.AdminService) *gin.Engine {
+ gin.SetMode(gin.TestMode)
+ router := gin.New()
+ h := NewAdminAPIKeyHandler(adminSvc)
+ router.PUT("/api/v1/admin/api-keys/:id", h.UpdateGroup)
+ return router
+}
+
+func TestAdminAPIKeyHandler_UpdateGroup_InvalidID(t *testing.T) {
+ router := setupAPIKeyHandler(newStubAdminService())
+ body := `{"group_id": 2}`
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/abc", bytes.NewBufferString(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.Contains(t, rec.Body.String(), "Invalid API key ID")
+}
+
+func TestAdminAPIKeyHandler_UpdateGroup_InvalidJSON(t *testing.T) {
+ router := setupAPIKeyHandler(newStubAdminService())
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{bad json`))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.Contains(t, rec.Body.String(), "Invalid request")
+}
+
+func TestAdminAPIKeyHandler_UpdateGroup_KeyNotFound(t *testing.T) {
+ router := setupAPIKeyHandler(newStubAdminService())
+ body := `{"group_id": 2}`
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/999", bytes.NewBufferString(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ // ErrAPIKeyNotFound maps to 404
+ require.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+func TestAdminAPIKeyHandler_UpdateGroup_BindGroup(t *testing.T) {
+ router := setupAPIKeyHandler(newStubAdminService())
+ body := `{"group_id": 2}`
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data json.RawMessage `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+
+ var data struct {
+ APIKey struct {
+ ID int64 `json:"id"`
+ GroupID *int64 `json:"group_id"`
+ } `json:"api_key"`
+ AutoGrantedGroupAccess bool `json:"auto_granted_group_access"`
+ }
+ require.NoError(t, json.Unmarshal(resp.Data, &data))
+ require.Equal(t, int64(10), data.APIKey.ID)
+ require.NotNil(t, data.APIKey.GroupID)
+ require.Equal(t, int64(2), *data.APIKey.GroupID)
+}
+
+func TestAdminAPIKeyHandler_UpdateGroup_Unbind(t *testing.T) {
+ svc := newStubAdminService()
+ gid := int64(2)
+ svc.apiKeys[0].GroupID = &gid
+ router := setupAPIKeyHandler(svc)
+ body := `{"group_id": 0}`
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ var resp struct {
+ Data struct {
+ APIKey struct {
+ GroupID *int64 `json:"group_id"`
+ } `json:"api_key"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ require.Nil(t, resp.Data.APIKey.GroupID)
+}
+
+func TestAdminAPIKeyHandler_UpdateGroup_ServiceError(t *testing.T) {
+ svc := &failingUpdateGroupService{
+ stubAdminService: newStubAdminService(),
+ err: errors.New("internal failure"),
+ }
+ router := setupAPIKeyHandler(svc)
+ body := `{"group_id": 2}`
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+}
+
+// H2: empty body → group_id is nil → no-op, returns original key
+func TestAdminAPIKeyHandler_UpdateGroup_EmptyBody_NoChange(t *testing.T) {
+ router := setupAPIKeyHandler(newStubAdminService())
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{}`))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ APIKey struct {
+ ID int64 `json:"id"`
+ } `json:"api_key"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, int64(10), resp.Data.APIKey.ID)
+}
+
+// M2: service returns GROUP_NOT_ACTIVE → handler maps to 400
+func TestAdminAPIKeyHandler_UpdateGroup_GroupNotActive(t *testing.T) {
+ svc := &failingUpdateGroupService{
+ stubAdminService: newStubAdminService(),
+ err: infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active"),
+ }
+ router := setupAPIKeyHandler(svc)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": 5}`))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.Contains(t, rec.Body.String(), "GROUP_NOT_ACTIVE")
+}
+
+// M2: service returns INVALID_GROUP_ID → handler maps to 400
+func TestAdminAPIKeyHandler_UpdateGroup_NegativeGroupID(t *testing.T) {
+ svc := &failingUpdateGroupService{
+ stubAdminService: newStubAdminService(),
+ err: infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative"),
+ }
+ router := setupAPIKeyHandler(svc)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": -5}`))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.Contains(t, rec.Body.String(), "INVALID_GROUP_ID")
+}
+
+// failingUpdateGroupService overrides AdminUpdateAPIKeyGroupID to return an error.
+type failingUpdateGroupService struct {
+ *stubAdminService
+ err error
+}
+
+func (f *failingUpdateGroupService) AdminUpdateAPIKeyGroupID(_ context.Context, _ int64, _ *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) {
+ return nil, f.err
+}
diff --git a/backend/internal/handler/admin/batch_update_credentials_test.go b/backend/internal/handler/admin/batch_update_credentials_test.go
index c8185735..0b1b6691 100644
--- a/backend/internal/handler/admin/batch_update_credentials_test.go
+++ b/backend/internal/handler/admin/batch_update_credentials_test.go
@@ -36,7 +36,7 @@ func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input
func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) {
gin.SetMode(gin.TestMode)
router := gin.New()
- handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
+ handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials)
return router, handler
}
diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go
index fab66c04..1d48c653 100644
--- a/backend/internal/handler/admin/dashboard_handler.go
+++ b/backend/internal/handler/admin/dashboard_handler.go
@@ -3,6 +3,7 @@ package admin
import (
"errors"
"strconv"
+ "strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -186,7 +187,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
// GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend
-// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type
+// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, request_type, stream, billing_type
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
@@ -194,6 +195,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// Parse optional filter params
var userID, apiKeyID, accountID, groupID int64
var model string
+ var requestType *int16
var stream *bool
var billingType *int8
@@ -220,9 +222,20 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
if modelStr := c.Query("model"); modelStr != "" {
model = modelStr
}
- if streamStr := c.Query("stream"); streamStr != "" {
+ if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
+ parsed, err := service.ParseUsageRequestType(requestTypeStr)
+ if err != nil {
+ response.BadRequest(c, err.Error())
+ return
+ }
+ value := int16(parsed)
+ requestType = &value
+ } else if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
+ } else {
+ response.BadRequest(c, "Invalid stream value, use true or false")
+ return
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
@@ -235,7 +248,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
}
}
- trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
+ trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
@@ -251,12 +264,13 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models
-// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type
+// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
// Parse optional filter params
var userID, apiKeyID, accountID, groupID int64
+ var requestType *int16
var stream *bool
var billingType *int8
@@ -280,9 +294,20 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
groupID = id
}
}
- if streamStr := c.Query("stream"); streamStr != "" {
+ if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
+ parsed, err := service.ParseUsageRequestType(requestTypeStr)
+ if err != nil {
+ response.BadRequest(c, err.Error())
+ return
+ }
+ value := int16(parsed)
+ requestType = &value
+ } else if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
+ } else {
+ response.BadRequest(c, "Invalid stream value, use true or false")
+ return
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
@@ -295,7 +320,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
}
}
- stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
+ stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
@@ -308,6 +333,76 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
})
}
+// GetGroupStats handles getting group usage statistics
+// GET /api/v1/admin/dashboard/groups
+// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type
+func (h *DashboardHandler) GetGroupStats(c *gin.Context) {
+ startTime, endTime := parseTimeRange(c)
+
+ var userID, apiKeyID, accountID, groupID int64
+ var requestType *int16
+ var stream *bool
+ var billingType *int8
+
+ if userIDStr := c.Query("user_id"); userIDStr != "" {
+ if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
+ userID = id
+ }
+ }
+ if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
+ if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
+ apiKeyID = id
+ }
+ }
+ if accountIDStr := c.Query("account_id"); accountIDStr != "" {
+ if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
+ accountID = id
+ }
+ }
+ if groupIDStr := c.Query("group_id"); groupIDStr != "" {
+ if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
+ groupID = id
+ }
+ }
+ if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
+ parsed, err := service.ParseUsageRequestType(requestTypeStr)
+ if err != nil {
+ response.BadRequest(c, err.Error())
+ return
+ }
+ value := int16(parsed)
+ requestType = &value
+ } else if streamStr := c.Query("stream"); streamStr != "" {
+ if streamVal, err := strconv.ParseBool(streamStr); err == nil {
+ stream = &streamVal
+ } else {
+ response.BadRequest(c, "Invalid stream value, use true or false")
+ return
+ }
+ }
+ if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
+ if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
+ bt := int8(v)
+ billingType = &bt
+ } else {
+ response.BadRequest(c, "Invalid billing_type")
+ return
+ }
+ }
+
+ stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
+ if err != nil {
+ response.Error(c, 500, "Failed to get group statistics")
+ return
+ }
+
+ response.Success(c, gin.H{
+ "groups": stats,
+ "start_date": startTime.Format("2006-01-02"),
+ "end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
+ })
+}
+
// GetAPIKeyUsageTrend handles getting API key usage trend data
// GET /api/v1/admin/dashboard/api-keys-trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 5)
diff --git a/backend/internal/handler/admin/dashboard_handler_request_type_test.go b/backend/internal/handler/admin/dashboard_handler_request_type_test.go
new file mode 100644
index 00000000..72af6b45
--- /dev/null
+++ b/backend/internal/handler/admin/dashboard_handler_request_type_test.go
@@ -0,0 +1,132 @@
+package admin
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type dashboardUsageRepoCapture struct {
+ service.UsageLogRepository
+ trendRequestType *int16
+ trendStream *bool
+ modelRequestType *int16
+ modelStream *bool
+}
+
+func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
+ ctx context.Context,
+ startTime, endTime time.Time,
+ granularity string,
+ userID, apiKeyID, accountID, groupID int64,
+ model string,
+ requestType *int16,
+ stream *bool,
+ billingType *int8,
+) ([]usagestats.TrendDataPoint, error) {
+ s.trendRequestType = requestType
+ s.trendStream = stream
+ return []usagestats.TrendDataPoint{}, nil
+}
+
+func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
+ ctx context.Context,
+ startTime, endTime time.Time,
+ userID, apiKeyID, accountID, groupID int64,
+ requestType *int16,
+ stream *bool,
+ billingType *int8,
+) ([]usagestats.ModelStat, error) {
+ s.modelRequestType = requestType
+ s.modelStream = stream
+ return []usagestats.ModelStat{}, nil
+}
+
+func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
+ gin.SetMode(gin.TestMode)
+ dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
+ handler := NewDashboardHandler(dashboardSvc, nil)
+ router := gin.New()
+ router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
+ router.GET("/admin/dashboard/models", handler.GetModelStats)
+ return router
+}
+
+func TestDashboardTrendRequestTypePriority(t *testing.T) {
+ repo := &dashboardUsageRepoCapture{}
+ router := newDashboardRequestTypeTestRouter(repo)
+
+ req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=ws_v2&stream=bad", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.NotNil(t, repo.trendRequestType)
+ require.Equal(t, int16(service.RequestTypeWSV2), *repo.trendRequestType)
+ require.Nil(t, repo.trendStream)
+}
+
+func TestDashboardTrendInvalidRequestType(t *testing.T) {
+ repo := &dashboardUsageRepoCapture{}
+ router := newDashboardRequestTypeTestRouter(repo)
+
+ req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=bad", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestDashboardTrendInvalidStream(t *testing.T) {
+ repo := &dashboardUsageRepoCapture{}
+ router := newDashboardRequestTypeTestRouter(repo)
+
+ req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?stream=bad", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestDashboardModelStatsRequestTypePriority(t *testing.T) {
+ repo := &dashboardUsageRepoCapture{}
+ router := newDashboardRequestTypeTestRouter(repo)
+
+ req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=sync&stream=bad", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.NotNil(t, repo.modelRequestType)
+ require.Equal(t, int16(service.RequestTypeSync), *repo.modelRequestType)
+ require.Nil(t, repo.modelStream)
+}
+
+func TestDashboardModelStatsInvalidRequestType(t *testing.T) {
+ repo := &dashboardUsageRepoCapture{}
+ router := newDashboardRequestTypeTestRouter(repo)
+
+ req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=bad", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestDashboardModelStatsInvalidStream(t *testing.T) {
+ repo := &dashboardUsageRepoCapture{}
+ router := newDashboardRequestTypeTestRouter(repo)
+
+ req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?stream=bad", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
diff --git a/backend/internal/handler/admin/data_management_handler.go b/backend/internal/handler/admin/data_management_handler.go
new file mode 100644
index 00000000..02fc766f
--- /dev/null
+++ b/backend/internal/handler/admin/data_management_handler.go
@@ -0,0 +1,545 @@
+package admin
+
+import (
+ "context"
+ "strconv"
+ "strings"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+type DataManagementHandler struct {
+ dataManagementService dataManagementService
+}
+
+func NewDataManagementHandler(dataManagementService *service.DataManagementService) *DataManagementHandler {
+ return &DataManagementHandler{dataManagementService: dataManagementService}
+}
+
+type dataManagementService interface {
+ GetConfig(ctx context.Context) (service.DataManagementConfig, error)
+ UpdateConfig(ctx context.Context, cfg service.DataManagementConfig) (service.DataManagementConfig, error)
+ ValidateS3(ctx context.Context, cfg service.DataManagementS3Config) (service.DataManagementTestS3Result, error)
+ CreateBackupJob(ctx context.Context, input service.DataManagementCreateBackupJobInput) (service.DataManagementBackupJob, error)
+ ListSourceProfiles(ctx context.Context, sourceType string) ([]service.DataManagementSourceProfile, error)
+ CreateSourceProfile(ctx context.Context, input service.DataManagementCreateSourceProfileInput) (service.DataManagementSourceProfile, error)
+ UpdateSourceProfile(ctx context.Context, input service.DataManagementUpdateSourceProfileInput) (service.DataManagementSourceProfile, error)
+ DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error
+ SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (service.DataManagementSourceProfile, error)
+ ListS3Profiles(ctx context.Context) ([]service.DataManagementS3Profile, error)
+ CreateS3Profile(ctx context.Context, input service.DataManagementCreateS3ProfileInput) (service.DataManagementS3Profile, error)
+ UpdateS3Profile(ctx context.Context, input service.DataManagementUpdateS3ProfileInput) (service.DataManagementS3Profile, error)
+ DeleteS3Profile(ctx context.Context, profileID string) error
+ SetActiveS3Profile(ctx context.Context, profileID string) (service.DataManagementS3Profile, error)
+ ListBackupJobs(ctx context.Context, input service.DataManagementListBackupJobsInput) (service.DataManagementListBackupJobsResult, error)
+ GetBackupJob(ctx context.Context, jobID string) (service.DataManagementBackupJob, error)
+ EnsureAgentEnabled(ctx context.Context) error
+ GetAgentHealth(ctx context.Context) service.DataManagementAgentHealth
+}
+
+type TestS3ConnectionRequest struct {
+ Endpoint string `json:"endpoint"`
+ Region string `json:"region" binding:"required"`
+ Bucket string `json:"bucket" binding:"required"`
+ AccessKeyID string `json:"access_key_id"`
+ SecretAccessKey string `json:"secret_access_key"`
+ Prefix string `json:"prefix"`
+ ForcePathStyle bool `json:"force_path_style"`
+ UseSSL bool `json:"use_ssl"`
+}
+
+type CreateBackupJobRequest struct {
+ BackupType string `json:"backup_type" binding:"required,oneof=postgres redis full"`
+ UploadToS3 bool `json:"upload_to_s3"`
+ S3ProfileID string `json:"s3_profile_id"`
+ PostgresID string `json:"postgres_profile_id"`
+ RedisID string `json:"redis_profile_id"`
+ IdempotencyKey string `json:"idempotency_key"`
+}
+
+type CreateSourceProfileRequest struct {
+ ProfileID string `json:"profile_id" binding:"required"`
+ Name string `json:"name" binding:"required"`
+ Config service.DataManagementSourceConfig `json:"config" binding:"required"`
+ SetActive bool `json:"set_active"`
+}
+
+type UpdateSourceProfileRequest struct {
+ Name string `json:"name" binding:"required"`
+ Config service.DataManagementSourceConfig `json:"config" binding:"required"`
+}
+
+type CreateS3ProfileRequest struct {
+ ProfileID string `json:"profile_id" binding:"required"`
+ Name string `json:"name" binding:"required"`
+ Enabled bool `json:"enabled"`
+ Endpoint string `json:"endpoint"`
+ Region string `json:"region"`
+ Bucket string `json:"bucket"`
+ AccessKeyID string `json:"access_key_id"`
+ SecretAccessKey string `json:"secret_access_key"`
+ Prefix string `json:"prefix"`
+ ForcePathStyle bool `json:"force_path_style"`
+ UseSSL bool `json:"use_ssl"`
+ SetActive bool `json:"set_active"`
+}
+
+type UpdateS3ProfileRequest struct {
+ Name string `json:"name" binding:"required"`
+ Enabled bool `json:"enabled"`
+ Endpoint string `json:"endpoint"`
+ Region string `json:"region"`
+ Bucket string `json:"bucket"`
+ AccessKeyID string `json:"access_key_id"`
+ SecretAccessKey string `json:"secret_access_key"`
+ Prefix string `json:"prefix"`
+ ForcePathStyle bool `json:"force_path_style"`
+ UseSSL bool `json:"use_ssl"`
+}
+
+func (h *DataManagementHandler) GetAgentHealth(c *gin.Context) {
+ health := h.getAgentHealth(c)
+ payload := gin.H{
+ "enabled": health.Enabled,
+ "reason": health.Reason,
+ "socket_path": health.SocketPath,
+ }
+ if health.Agent != nil {
+ payload["agent"] = gin.H{
+ "status": health.Agent.Status,
+ "version": health.Agent.Version,
+ "uptime_seconds": health.Agent.UptimeSeconds,
+ }
+ }
+ response.Success(c, payload)
+}
+
+func (h *DataManagementHandler) GetConfig(c *gin.Context) {
+ if !h.requireAgentEnabled(c) {
+ return
+ }
+ cfg, err := h.dataManagementService.GetConfig(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, cfg)
+}
+
+func (h *DataManagementHandler) UpdateConfig(c *gin.Context) {
+ var req service.DataManagementConfig
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if !h.requireAgentEnabled(c) {
+ return
+ }
+ cfg, err := h.dataManagementService.UpdateConfig(c.Request.Context(), req)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, cfg)
+}
+
+func (h *DataManagementHandler) TestS3(c *gin.Context) {
+ var req TestS3ConnectionRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if !h.requireAgentEnabled(c) {
+ return
+ }
+ result, err := h.dataManagementService.ValidateS3(c.Request.Context(), service.DataManagementS3Config{
+ Enabled: true,
+ Endpoint: req.Endpoint,
+ Region: req.Region,
+ Bucket: req.Bucket,
+ AccessKeyID: req.AccessKeyID,
+ SecretAccessKey: req.SecretAccessKey,
+ Prefix: req.Prefix,
+ ForcePathStyle: req.ForcePathStyle,
+ UseSSL: req.UseSSL,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"ok": result.OK, "message": result.Message})
+}
+
+func (h *DataManagementHandler) CreateBackupJob(c *gin.Context) {
+ var req CreateBackupJobRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ req.IdempotencyKey = normalizeBackupIdempotencyKey(c.GetHeader("X-Idempotency-Key"), req.IdempotencyKey)
+ if !h.requireAgentEnabled(c) {
+ return
+ }
+
+ triggeredBy := "admin:unknown"
+ if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
+ triggeredBy = "admin:" + strconv.FormatInt(subject.UserID, 10)
+ }
+ job, err := h.dataManagementService.CreateBackupJob(c.Request.Context(), service.DataManagementCreateBackupJobInput{
+ BackupType: req.BackupType,
+ UploadToS3: req.UploadToS3,
+ S3ProfileID: req.S3ProfileID,
+ PostgresID: req.PostgresID,
+ RedisID: req.RedisID,
+ TriggeredBy: triggeredBy,
+ IdempotencyKey: req.IdempotencyKey,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"job_id": job.JobID, "status": job.Status})
+}
+
+func (h *DataManagementHandler) ListSourceProfiles(c *gin.Context) {
+ sourceType := strings.TrimSpace(c.Param("source_type"))
+ if sourceType == "" {
+ response.BadRequest(c, "Invalid source_type")
+ return
+ }
+ if sourceType != "postgres" && sourceType != "redis" {
+ response.BadRequest(c, "source_type must be postgres or redis")
+ return
+ }
+
+ if !h.requireAgentEnabled(c) {
+ return
+ }
+ items, err := h.dataManagementService.ListSourceProfiles(c.Request.Context(), sourceType)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"items": items})
+}
+
+func (h *DataManagementHandler) CreateSourceProfile(c *gin.Context) {
+ sourceType := strings.TrimSpace(c.Param("source_type"))
+ if sourceType != "postgres" && sourceType != "redis" {
+ response.BadRequest(c, "source_type must be postgres or redis")
+ return
+ }
+
+ var req CreateSourceProfileRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if !h.requireAgentEnabled(c) {
+ return
+ }
+ profile, err := h.dataManagementService.CreateSourceProfile(c.Request.Context(), service.DataManagementCreateSourceProfileInput{
+ SourceType: sourceType,
+ ProfileID: req.ProfileID,
+ Name: req.Name,
+ Config: req.Config,
+ SetActive: req.SetActive,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, profile)
+}
+
+func (h *DataManagementHandler) UpdateSourceProfile(c *gin.Context) {
+ sourceType := strings.TrimSpace(c.Param("source_type"))
+ if sourceType != "postgres" && sourceType != "redis" {
+ response.BadRequest(c, "source_type must be postgres or redis")
+ return
+ }
+ profileID := strings.TrimSpace(c.Param("profile_id"))
+ if profileID == "" {
+ response.BadRequest(c, "Invalid profile_id")
+ return
+ }
+
+ var req UpdateSourceProfileRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if !h.requireAgentEnabled(c) {
+ return
+ }
+ profile, err := h.dataManagementService.UpdateSourceProfile(c.Request.Context(), service.DataManagementUpdateSourceProfileInput{
+ SourceType: sourceType,
+ ProfileID: profileID,
+ Name: req.Name,
+ Config: req.Config,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, profile)
+}
+
+func (h *DataManagementHandler) DeleteSourceProfile(c *gin.Context) {
+ sourceType := strings.TrimSpace(c.Param("source_type"))
+ if sourceType != "postgres" && sourceType != "redis" {
+ response.BadRequest(c, "source_type must be postgres or redis")
+ return
+ }
+ profileID := strings.TrimSpace(c.Param("profile_id"))
+ if profileID == "" {
+ response.BadRequest(c, "Invalid profile_id")
+ return
+ }
+
+ if !h.requireAgentEnabled(c) {
+ return
+ }
+ if err := h.dataManagementService.DeleteSourceProfile(c.Request.Context(), sourceType, profileID); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"deleted": true})
+}
+
+func (h *DataManagementHandler) SetActiveSourceProfile(c *gin.Context) {
+ sourceType := strings.TrimSpace(c.Param("source_type"))
+ if sourceType != "postgres" && sourceType != "redis" {
+ response.BadRequest(c, "source_type must be postgres or redis")
+ return
+ }
+ profileID := strings.TrimSpace(c.Param("profile_id"))
+ if profileID == "" {
+ response.BadRequest(c, "Invalid profile_id")
+ return
+ }
+
+ if !h.requireAgentEnabled(c) {
+ return
+ }
+ profile, err := h.dataManagementService.SetActiveSourceProfile(c.Request.Context(), sourceType, profileID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, profile)
+}
+
+func (h *DataManagementHandler) ListS3Profiles(c *gin.Context) {
+ if !h.requireAgentEnabled(c) {
+ return
+ }
+
+ items, err := h.dataManagementService.ListS3Profiles(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"items": items})
+}
+
+func (h *DataManagementHandler) CreateS3Profile(c *gin.Context) {
+ var req CreateS3ProfileRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if !h.requireAgentEnabled(c) {
+ return
+ }
+
+ profile, err := h.dataManagementService.CreateS3Profile(c.Request.Context(), service.DataManagementCreateS3ProfileInput{
+ ProfileID: req.ProfileID,
+ Name: req.Name,
+ SetActive: req.SetActive,
+ S3: service.DataManagementS3Config{
+ Enabled: req.Enabled,
+ Endpoint: req.Endpoint,
+ Region: req.Region,
+ Bucket: req.Bucket,
+ AccessKeyID: req.AccessKeyID,
+ SecretAccessKey: req.SecretAccessKey,
+ Prefix: req.Prefix,
+ ForcePathStyle: req.ForcePathStyle,
+ UseSSL: req.UseSSL,
+ },
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, profile)
+}
+
+func (h *DataManagementHandler) UpdateS3Profile(c *gin.Context) {
+ var req UpdateS3ProfileRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ profileID := strings.TrimSpace(c.Param("profile_id"))
+ if profileID == "" {
+ response.BadRequest(c, "Invalid profile_id")
+ return
+ }
+
+ if !h.requireAgentEnabled(c) {
+ return
+ }
+
+ profile, err := h.dataManagementService.UpdateS3Profile(c.Request.Context(), service.DataManagementUpdateS3ProfileInput{
+ ProfileID: profileID,
+ Name: req.Name,
+ S3: service.DataManagementS3Config{
+ Enabled: req.Enabled,
+ Endpoint: req.Endpoint,
+ Region: req.Region,
+ Bucket: req.Bucket,
+ AccessKeyID: req.AccessKeyID,
+ SecretAccessKey: req.SecretAccessKey,
+ Prefix: req.Prefix,
+ ForcePathStyle: req.ForcePathStyle,
+ UseSSL: req.UseSSL,
+ },
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, profile)
+}
+
+func (h *DataManagementHandler) DeleteS3Profile(c *gin.Context) {
+ profileID := strings.TrimSpace(c.Param("profile_id"))
+ if profileID == "" {
+ response.BadRequest(c, "Invalid profile_id")
+ return
+ }
+
+ if !h.requireAgentEnabled(c) {
+ return
+ }
+ if err := h.dataManagementService.DeleteS3Profile(c.Request.Context(), profileID); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"deleted": true})
+}
+
+func (h *DataManagementHandler) SetActiveS3Profile(c *gin.Context) {
+ profileID := strings.TrimSpace(c.Param("profile_id"))
+ if profileID == "" {
+ response.BadRequest(c, "Invalid profile_id")
+ return
+ }
+
+ if !h.requireAgentEnabled(c) {
+ return
+ }
+ profile, err := h.dataManagementService.SetActiveS3Profile(c.Request.Context(), profileID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, profile)
+}
+
+func (h *DataManagementHandler) ListBackupJobs(c *gin.Context) {
+ if !h.requireAgentEnabled(c) {
+ return
+ }
+
+ pageSize := int32(20)
+ if raw := strings.TrimSpace(c.Query("page_size")); raw != "" {
+ v, err := strconv.Atoi(raw)
+ if err != nil || v <= 0 {
+ response.BadRequest(c, "Invalid page_size")
+ return
+ }
+ pageSize = int32(v)
+ }
+
+ result, err := h.dataManagementService.ListBackupJobs(c.Request.Context(), service.DataManagementListBackupJobsInput{
+ PageSize: pageSize,
+ PageToken: c.Query("page_token"),
+ Status: c.Query("status"),
+ BackupType: c.Query("backup_type"),
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, result)
+}
+
+func (h *DataManagementHandler) GetBackupJob(c *gin.Context) {
+ jobID := strings.TrimSpace(c.Param("job_id"))
+ if jobID == "" {
+ response.BadRequest(c, "Invalid backup job ID")
+ return
+ }
+
+ if !h.requireAgentEnabled(c) {
+ return
+ }
+ job, err := h.dataManagementService.GetBackupJob(c.Request.Context(), jobID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, job)
+}
+
+func (h *DataManagementHandler) requireAgentEnabled(c *gin.Context) bool {
+ if h.dataManagementService == nil {
+ err := infraerrors.ServiceUnavailable(
+ service.DataManagementAgentUnavailableReason,
+ "data management agent service is not configured",
+ ).WithMetadata(map[string]string{"socket_path": service.DefaultDataManagementAgentSocketPath})
+ response.ErrorFrom(c, err)
+ return false
+ }
+
+ if err := h.dataManagementService.EnsureAgentEnabled(c.Request.Context()); err != nil {
+ response.ErrorFrom(c, err)
+ return false
+ }
+
+ return true
+}
+
+func (h *DataManagementHandler) getAgentHealth(c *gin.Context) service.DataManagementAgentHealth {
+ if h.dataManagementService == nil {
+ return service.DataManagementAgentHealth{
+ Enabled: false,
+ Reason: service.DataManagementAgentUnavailableReason,
+ SocketPath: service.DefaultDataManagementAgentSocketPath,
+ }
+ }
+ return h.dataManagementService.GetAgentHealth(c.Request.Context())
+}
+
+func normalizeBackupIdempotencyKey(headerValue, bodyValue string) string {
+ headerKey := strings.TrimSpace(headerValue)
+ if headerKey != "" {
+ return headerKey
+ }
+ return strings.TrimSpace(bodyValue)
+}
diff --git a/backend/internal/handler/admin/data_management_handler_test.go b/backend/internal/handler/admin/data_management_handler_test.go
new file mode 100644
index 00000000..ce8ee835
--- /dev/null
+++ b/backend/internal/handler/admin/data_management_handler_test.go
@@ -0,0 +1,78 @@
+package admin
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type apiEnvelope struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Reason string `json:"reason"`
+ Data json.RawMessage `json:"data"`
+}
+
+func TestDataManagementHandler_AgentHealthAlways200(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
+ h := NewDataManagementHandler(svc)
+
+ r := gin.New()
+ r.GET("/api/v1/admin/data-management/agent/health", h.GetAgentHealth)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/agent/health", nil)
+ r.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ var envelope apiEnvelope
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
+ require.Equal(t, 0, envelope.Code)
+
+ var data struct {
+ Enabled bool `json:"enabled"`
+ Reason string `json:"reason"`
+ SocketPath string `json:"socket_path"`
+ }
+ require.NoError(t, json.Unmarshal(envelope.Data, &data))
+ require.False(t, data.Enabled)
+ require.Equal(t, service.DataManagementDeprecatedReason, data.Reason)
+ require.Equal(t, svc.SocketPath(), data.SocketPath)
+}
+
+func TestDataManagementHandler_NonHealthRouteReturns503WhenDisabled(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
+ h := NewDataManagementHandler(svc)
+
+ r := gin.New()
+ r.GET("/api/v1/admin/data-management/config", h.GetConfig)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/config", nil)
+ r.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusServiceUnavailable, rec.Code)
+
+ var envelope apiEnvelope
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
+ require.Equal(t, http.StatusServiceUnavailable, envelope.Code)
+ require.Equal(t, service.DataManagementDeprecatedReason, envelope.Reason)
+}
+
+func TestNormalizeBackupIdempotencyKey(t *testing.T) {
+ require.Equal(t, "from-header", normalizeBackupIdempotencyKey("from-header", "from-body"))
+ require.Equal(t, "from-body", normalizeBackupIdempotencyKey(" ", " from-body "))
+ require.Equal(t, "", normalizeBackupIdempotencyKey("", ""))
+}
diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go
index 25ff3c96..1edf4dcc 100644
--- a/backend/internal/handler/admin/group_handler.go
+++ b/backend/internal/handler/admin/group_handler.go
@@ -51,6 +51,8 @@ type CreateGroupRequest struct {
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
+ // Sora 存储配额
+ SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
@@ -84,6 +86,8 @@ type UpdateGroupRequest struct {
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string `json:"supported_model_scopes"`
+ // Sora 存储配额
+ SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
@@ -198,6 +202,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
+ SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
@@ -248,6 +253,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
+ SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
diff --git a/backend/internal/handler/admin/openai_oauth_handler.go b/backend/internal/handler/admin/openai_oauth_handler.go
index cf43f89e..5d354fd3 100644
--- a/backend/internal/handler/admin/openai_oauth_handler.go
+++ b/backend/internal/handler/admin/openai_oauth_handler.go
@@ -5,6 +5,7 @@ import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -47,7 +48,12 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
req = OpenAIGenerateAuthURLRequest{}
}
- result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI)
+ result, err := h.openaiOAuthService.GenerateAuthURL(
+ c.Request.Context(),
+ req.ProxyID,
+ req.RedirectURI,
+ oauthPlatformFromPath(c),
+ )
if err != nil {
response.ErrorFrom(c, err)
return
@@ -123,7 +129,14 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
}
}
- tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID))
+ // 未指定 client_id 时,根据请求路径平台自动设置默认值,避免 repository 层盲猜
+ clientID := strings.TrimSpace(req.ClientID)
+ if clientID == "" {
+ platform := oauthPlatformFromPath(c)
+ clientID, _ = openai.OAuthClientConfigByPlatform(platform)
+ }
+
+ tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, clientID)
if err != nil {
response.ErrorFrom(c, err)
return
diff --git a/backend/internal/handler/admin/ops_ws_handler.go b/backend/internal/handler/admin/ops_ws_handler.go
index c030d303..75fd7ea0 100644
--- a/backend/internal/handler/admin/ops_ws_handler.go
+++ b/backend/internal/handler/admin/ops_ws_handler.go
@@ -62,7 +62,8 @@ const (
)
var wsConnCount atomic.Int32
-var wsConnCountByIP sync.Map // map[string]*atomic.Int32
+var wsConnCountByIPMu sync.Mutex
+var wsConnCountByIP = make(map[string]int32)
const qpsWSIdleStopDelay = 30 * time.Second
@@ -389,42 +390,31 @@ func tryAcquireOpsWSIPSlot(clientIP string, limit int32) bool {
if strings.TrimSpace(clientIP) == "" || limit <= 0 {
return true
}
-
- v, _ := wsConnCountByIP.LoadOrStore(clientIP, &atomic.Int32{})
- counter, ok := v.(*atomic.Int32)
- if !ok {
+ wsConnCountByIPMu.Lock()
+ defer wsConnCountByIPMu.Unlock()
+ current := wsConnCountByIP[clientIP]
+ if current >= limit {
return false
}
-
- for {
- current := counter.Load()
- if current >= limit {
- return false
- }
- if counter.CompareAndSwap(current, current+1) {
- return true
- }
- }
+ wsConnCountByIP[clientIP] = current + 1
+ return true
}
func releaseOpsWSIPSlot(clientIP string) {
if strings.TrimSpace(clientIP) == "" {
return
}
-
- v, ok := wsConnCountByIP.Load(clientIP)
+ wsConnCountByIPMu.Lock()
+ defer wsConnCountByIPMu.Unlock()
+ current, ok := wsConnCountByIP[clientIP]
if !ok {
return
}
- counter, ok := v.(*atomic.Int32)
- if !ok {
+ if current <= 1 {
+ delete(wsConnCountByIP, clientIP)
return
}
- next := counter.Add(-1)
- if next <= 0 {
- // Best-effort cleanup; safe even if a new slot was acquired concurrently.
- wsConnCountByIP.Delete(clientIP)
- }
+ wsConnCountByIP[clientIP] = current - 1
}
func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go
index 9fd187fc..e8ae0ce2 100644
--- a/backend/internal/handler/admin/proxy_handler.go
+++ b/backend/internal/handler/admin/proxy_handler.go
@@ -64,9 +64,9 @@ func (h *ProxyHandler) List(c *gin.Context) {
return
}
- out := make([]dto.ProxyWithAccountCount, 0, len(proxies))
+ out := make([]dto.AdminProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
- out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i]))
+ out = append(out, *dto.ProxyWithAccountCountFromServiceAdmin(&proxies[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
@@ -83,9 +83,9 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- out := make([]dto.ProxyWithAccountCount, 0, len(proxies))
+ out := make([]dto.AdminProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
- out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i]))
+ out = append(out, *dto.ProxyWithAccountCountFromServiceAdmin(&proxies[i]))
}
response.Success(c, out)
return
@@ -97,9 +97,9 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
return
}
- out := make([]dto.Proxy, 0, len(proxies))
+ out := make([]dto.AdminProxy, 0, len(proxies))
for i := range proxies {
- out = append(out, *dto.ProxyFromService(&proxies[i]))
+ out = append(out, *dto.ProxyFromServiceAdmin(&proxies[i]))
}
response.Success(c, out)
}
@@ -119,7 +119,7 @@ func (h *ProxyHandler) GetByID(c *gin.Context) {
return
}
- response.Success(c, dto.ProxyFromService(proxy))
+ response.Success(c, dto.ProxyFromServiceAdmin(proxy))
}
// Create handles creating a new proxy
@@ -143,7 +143,7 @@ func (h *ProxyHandler) Create(c *gin.Context) {
if err != nil {
return nil, err
}
- return dto.ProxyFromService(proxy), nil
+ return dto.ProxyFromServiceAdmin(proxy), nil
})
}
@@ -176,7 +176,7 @@ func (h *ProxyHandler) Update(c *gin.Context) {
return
}
- response.Success(c, dto.ProxyFromService(proxy))
+ response.Success(c, dto.ProxyFromServiceAdmin(proxy))
}
// Delete handles deleting a proxy
diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go
index 7073061d..0a932ee9 100644
--- a/backend/internal/handler/admin/redeem_handler.go
+++ b/backend/internal/handler/admin/redeem_handler.go
@@ -4,11 +4,13 @@ import (
"bytes"
"context"
"encoding/csv"
+ "errors"
"fmt"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -17,13 +19,15 @@ import (
// RedeemHandler handles admin redeem code management
type RedeemHandler struct {
- adminService service.AdminService
+ adminService service.AdminService
+ redeemService *service.RedeemService
}
// NewRedeemHandler creates a new admin redeem handler
-func NewRedeemHandler(adminService service.AdminService) *RedeemHandler {
+func NewRedeemHandler(adminService service.AdminService, redeemService *service.RedeemService) *RedeemHandler {
return &RedeemHandler{
- adminService: adminService,
+ adminService: adminService,
+ redeemService: redeemService,
}
}
@@ -36,6 +40,15 @@ type GenerateRedeemCodesRequest struct {
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用,默认30天,最大100年
}
+// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
+type CreateAndRedeemCodeRequest struct {
+ Code string `json:"code" binding:"required,min=3,max=128"`
+ Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
+ Value float64 `json:"value" binding:"required,gt=0"`
+ UserID int64 `json:"user_id" binding:"required,gt=0"`
+ Notes string `json:"notes"`
+}
+
// List handles listing all redeem codes with pagination
// GET /api/v1/admin/redeem-codes
func (h *RedeemHandler) List(c *gin.Context) {
@@ -109,6 +122,81 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
})
}
+// CreateAndRedeem creates a fixed redeem code and redeems it for a target user in one step.
+// POST /api/v1/admin/redeem-codes/create-and-redeem
+func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
+ if h.redeemService == nil {
+ response.InternalError(c, "redeem service not configured")
+ return
+ }
+
+ var req CreateAndRedeemCodeRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+ req.Code = strings.TrimSpace(req.Code)
+
+ executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
+ existing, err := h.redeemService.GetByCode(ctx, req.Code)
+ if err == nil {
+ return h.resolveCreateAndRedeemExisting(ctx, existing, req.UserID)
+ }
+ if !errors.Is(err, service.ErrRedeemCodeNotFound) {
+ return nil, err
+ }
+
+ createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
+ Code: req.Code,
+ Type: req.Type,
+ Value: req.Value,
+ Status: service.StatusUnused,
+ Notes: req.Notes,
+ })
+ if createErr != nil {
+ // Unique code race: if code now exists, use idempotent semantics by used_by.
+ existingAfterCreateErr, getErr := h.redeemService.GetByCode(ctx, req.Code)
+ if getErr == nil {
+ return h.resolveCreateAndRedeemExisting(ctx, existingAfterCreateErr, req.UserID)
+ }
+ return nil, createErr
+ }
+
+ redeemed, redeemErr := h.redeemService.Redeem(ctx, req.UserID, req.Code)
+ if redeemErr != nil {
+ return nil, redeemErr
+ }
+ return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(redeemed)}, nil
+ })
+}
+
+func (h *RedeemHandler) resolveCreateAndRedeemExisting(ctx context.Context, existing *service.RedeemCode, userID int64) (any, error) {
+ if existing == nil {
+ return nil, infraerrors.Conflict("REDEEM_CODE_CONFLICT", "redeem code conflict")
+ }
+
+ // If previous run created the code but crashed before redeem, redeem it now.
+ if existing.CanUse() {
+ redeemed, err := h.redeemService.Redeem(ctx, userID, existing.Code)
+ if err == nil {
+ return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(redeemed)}, nil
+ }
+ if !errors.Is(err, service.ErrRedeemCodeUsed) {
+ return nil, err
+ }
+ latest, getErr := h.redeemService.GetByCode(ctx, existing.Code)
+ if getErr == nil {
+ existing = latest
+ }
+ }
+
+ if existing.UsedBy != nil && *existing.UsedBy == userID {
+ return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(existing)}, nil
+ }
+
+ return nil, infraerrors.Conflict("REDEEM_CODE_CONFLICT", "redeem code already used by another user")
+}
+
// Delete handles deleting a redeem code
// DELETE /api/v1/admin/redeem-codes/:id
func (h *RedeemHandler) Delete(c *gin.Context) {
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index 1e723ee5..e7da042c 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -1,7 +1,10 @@
package admin
import (
+ "fmt"
"log"
+ "net/http"
+ "regexp"
"strings"
"time"
@@ -14,21 +17,26 @@ import (
"github.com/gin-gonic/gin"
)
+// semverPattern 预编译 semver 格式校验正则
+var semverPattern = regexp.MustCompile(`^\d+\.\d+\.\d+$`)
+
// SettingHandler 系统设置处理器
type SettingHandler struct {
settingService *service.SettingService
emailService *service.EmailService
turnstileService *service.TurnstileService
opsService *service.OpsService
+ soraS3Storage *service.SoraS3Storage
}
// NewSettingHandler 创建系统设置处理器
-func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService) *SettingHandler {
+func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, soraS3Storage *service.SoraS3Storage) *SettingHandler {
return &SettingHandler{
settingService: settingService,
emailService: emailService,
turnstileService: turnstileService,
opsService: opsService,
+ soraS3Storage: soraS3Storage,
}
}
@@ -43,6 +51,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
// Check if ops monitoring is enabled (respects config.ops.enabled)
opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context())
+ defaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(settings.DefaultSubscriptions))
+ for _, sub := range settings.DefaultSubscriptions {
+ defaultSubscriptions = append(defaultSubscriptions, dto.DefaultSubscriptionSetting{
+ GroupID: sub.GroupID,
+ ValidityDays: sub.ValidityDays,
+ })
+ }
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
@@ -76,8 +91,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
+ SoraClientEnabled: settings.SoraClientEnabled,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
+ DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: settings.EnableModelFallback,
FallbackModelAnthropic: settings.FallbackModelAnthropic,
FallbackModelOpenAI: settings.FallbackModelOpenAI,
@@ -89,6 +106,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
OpsQueryModeDefault: settings.OpsQueryModeDefault,
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
+ MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
})
}
@@ -133,10 +151,12 @@ type UpdateSettingsRequest struct {
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
+ SoraClientEnabled bool `json:"sora_client_enabled"`
// 默认配置
- DefaultConcurrency int `json:"default_concurrency"`
- DefaultBalance float64 `json:"default_balance"`
+ DefaultConcurrency int `json:"default_concurrency"`
+ DefaultBalance float64 `json:"default_balance"`
+ DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -154,6 +174,8 @@ type UpdateSettingsRequest struct {
OpsRealtimeMonitoringEnabled *bool `json:"ops_realtime_monitoring_enabled"`
OpsQueryModeDefault *string `json:"ops_query_mode_default"`
OpsMetricsIntervalSeconds *int `json:"ops_metrics_interval_seconds"`
+
+ MinClaudeCodeVersion string `json:"min_claude_code_version"`
}
// UpdateSettings 更新系统设置
@@ -181,6 +203,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if req.SMTPPort <= 0 {
req.SMTPPort = 587
}
+ req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
// Turnstile 参数验证
if req.TurnstileEnabled {
@@ -287,6 +310,21 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
req.OpsMetricsIntervalSeconds = &v
}
+ defaultSubscriptions := make([]service.DefaultSubscriptionSetting, 0, len(req.DefaultSubscriptions))
+ for _, sub := range req.DefaultSubscriptions {
+ defaultSubscriptions = append(defaultSubscriptions, service.DefaultSubscriptionSetting{
+ GroupID: sub.GroupID,
+ ValidityDays: sub.ValidityDays,
+ })
+ }
+
+ // 验证最低版本号格式(空字符串=禁用,或合法 semver)
+ if req.MinClaudeCodeVersion != "" {
+ if !semverPattern.MatchString(req.MinClaudeCodeVersion) {
+ response.Error(c, http.StatusBadRequest, "min_claude_code_version must be empty or a valid semver (e.g. 2.1.63)")
+ return
+ }
+ }
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
@@ -319,8 +357,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
HideCcsImportButton: req.HideCcsImportButton,
PurchaseSubscriptionEnabled: purchaseEnabled,
PurchaseSubscriptionURL: purchaseURL,
+ SoraClientEnabled: req.SoraClientEnabled,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
+ DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
@@ -328,6 +368,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
+ MinClaudeCodeVersion: req.MinClaudeCodeVersion,
OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled
@@ -367,6 +408,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions))
+ for _, sub := range updatedSettings.DefaultSubscriptions {
+ updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{
+ GroupID: sub.GroupID,
+ ValidityDays: sub.ValidityDays,
+ })
+ }
response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
@@ -400,8 +448,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
HideCcsImportButton: updatedSettings.HideCcsImportButton,
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
+ SoraClientEnabled: updatedSettings.SoraClientEnabled,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
+ DefaultSubscriptions: updatedDefaultSubscriptions,
EnableModelFallback: updatedSettings.EnableModelFallback,
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
@@ -413,6 +463,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
+ MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
})
}
@@ -522,6 +573,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.DefaultBalance != after.DefaultBalance {
changed = append(changed, "default_balance")
}
+ if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) {
+ changed = append(changed, "default_subscriptions")
+ }
if before.EnableModelFallback != after.EnableModelFallback {
changed = append(changed, "enable_model_fallback")
}
@@ -555,9 +609,41 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.OpsMetricsIntervalSeconds != after.OpsMetricsIntervalSeconds {
changed = append(changed, "ops_metrics_interval_seconds")
}
+ if before.MinClaudeCodeVersion != after.MinClaudeCodeVersion {
+ changed = append(changed, "min_claude_code_version")
+ }
return changed
}
+func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto.DefaultSubscriptionSetting {
+ if len(input) == 0 {
+ return nil
+ }
+ normalized := make([]dto.DefaultSubscriptionSetting, 0, len(input))
+ for _, item := range input {
+ if item.GroupID <= 0 || item.ValidityDays <= 0 {
+ continue
+ }
+ if item.ValidityDays > service.MaxValidityDays {
+ item.ValidityDays = service.MaxValidityDays
+ }
+ normalized = append(normalized, item)
+ }
+ return normalized
+}
+
+func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for i := range a {
+ if a[i].GroupID != b[i].GroupID || a[i].ValidityDays != b[i].ValidityDays {
+ return false
+ }
+ }
+ return true
+}
+
// TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"`
@@ -750,6 +836,384 @@ func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
})
}
+func toSoraS3SettingsDTO(settings *service.SoraS3Settings) dto.SoraS3Settings {
+ if settings == nil {
+ return dto.SoraS3Settings{}
+ }
+ return dto.SoraS3Settings{
+ Enabled: settings.Enabled,
+ Endpoint: settings.Endpoint,
+ Region: settings.Region,
+ Bucket: settings.Bucket,
+ AccessKeyID: settings.AccessKeyID,
+ SecretAccessKeyConfigured: settings.SecretAccessKeyConfigured,
+ Prefix: settings.Prefix,
+ ForcePathStyle: settings.ForcePathStyle,
+ CDNURL: settings.CDNURL,
+ DefaultStorageQuotaBytes: settings.DefaultStorageQuotaBytes,
+ }
+}
+
+func toSoraS3ProfileDTO(profile service.SoraS3Profile) dto.SoraS3Profile {
+ return dto.SoraS3Profile{
+ ProfileID: profile.ProfileID,
+ Name: profile.Name,
+ IsActive: profile.IsActive,
+ Enabled: profile.Enabled,
+ Endpoint: profile.Endpoint,
+ Region: profile.Region,
+ Bucket: profile.Bucket,
+ AccessKeyID: profile.AccessKeyID,
+ SecretAccessKeyConfigured: profile.SecretAccessKeyConfigured,
+ Prefix: profile.Prefix,
+ ForcePathStyle: profile.ForcePathStyle,
+ CDNURL: profile.CDNURL,
+ DefaultStorageQuotaBytes: profile.DefaultStorageQuotaBytes,
+ UpdatedAt: profile.UpdatedAt,
+ }
+}
+
+func validateSoraS3RequiredWhenEnabled(enabled bool, endpoint, bucket, accessKeyID, secretAccessKey string, hasStoredSecret bool) error {
+ if !enabled {
+ return nil
+ }
+ if strings.TrimSpace(endpoint) == "" {
+ return fmt.Errorf("S3 Endpoint is required when enabled")
+ }
+ if strings.TrimSpace(bucket) == "" {
+ return fmt.Errorf("S3 Bucket is required when enabled")
+ }
+ if strings.TrimSpace(accessKeyID) == "" {
+ return fmt.Errorf("S3 Access Key ID is required when enabled")
+ }
+ if strings.TrimSpace(secretAccessKey) != "" || hasStoredSecret {
+ return nil
+ }
+ return fmt.Errorf("S3 Secret Access Key is required when enabled")
+}
+
+func findSoraS3ProfileByID(items []service.SoraS3Profile, profileID string) *service.SoraS3Profile {
+ for idx := range items {
+ if items[idx].ProfileID == profileID {
+ return &items[idx]
+ }
+ }
+ return nil
+}
+
+// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置接口)
+// GET /api/v1/admin/settings/sora-s3
+func (h *SettingHandler) GetSoraS3Settings(c *gin.Context) {
+ settings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, toSoraS3SettingsDTO(settings))
+}
+
+// ListSoraS3Profiles 获取 Sora S3 多配置
+// GET /api/v1/admin/settings/sora-s3/profiles
+func (h *SettingHandler) ListSoraS3Profiles(c *gin.Context) {
+ result, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ items := make([]dto.SoraS3Profile, 0, len(result.Items))
+ for idx := range result.Items {
+ items = append(items, toSoraS3ProfileDTO(result.Items[idx]))
+ }
+ response.Success(c, dto.ListSoraS3ProfilesResponse{
+ ActiveProfileID: result.ActiveProfileID,
+ Items: items,
+ })
+}
+
+// UpdateSoraS3SettingsRequest 更新/测试 Sora S3 配置请求(兼容旧接口)
+type UpdateSoraS3SettingsRequest struct {
+ ProfileID string `json:"profile_id"`
+ Enabled bool `json:"enabled"`
+ Endpoint string `json:"endpoint"`
+ Region string `json:"region"`
+ Bucket string `json:"bucket"`
+ AccessKeyID string `json:"access_key_id"`
+ SecretAccessKey string `json:"secret_access_key"`
+ Prefix string `json:"prefix"`
+ ForcePathStyle bool `json:"force_path_style"`
+ CDNURL string `json:"cdn_url"`
+ DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
+}
+
+type CreateSoraS3ProfileRequest struct {
+ ProfileID string `json:"profile_id"`
+ Name string `json:"name"`
+ SetActive bool `json:"set_active"`
+ Enabled bool `json:"enabled"`
+ Endpoint string `json:"endpoint"`
+ Region string `json:"region"`
+ Bucket string `json:"bucket"`
+ AccessKeyID string `json:"access_key_id"`
+ SecretAccessKey string `json:"secret_access_key"`
+ Prefix string `json:"prefix"`
+ ForcePathStyle bool `json:"force_path_style"`
+ CDNURL string `json:"cdn_url"`
+ DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
+}
+
+type UpdateSoraS3ProfileRequest struct {
+ Name string `json:"name"`
+ Enabled bool `json:"enabled"`
+ Endpoint string `json:"endpoint"`
+ Region string `json:"region"`
+ Bucket string `json:"bucket"`
+ AccessKeyID string `json:"access_key_id"`
+ SecretAccessKey string `json:"secret_access_key"`
+ Prefix string `json:"prefix"`
+ ForcePathStyle bool `json:"force_path_style"`
+ CDNURL string `json:"cdn_url"`
+ DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
+}
+
+// CreateSoraS3Profile 创建 Sora S3 配置
+// POST /api/v1/admin/settings/sora-s3/profiles
+func (h *SettingHandler) CreateSoraS3Profile(c *gin.Context) {
+ var req CreateSoraS3ProfileRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if req.DefaultStorageQuotaBytes < 0 {
+ req.DefaultStorageQuotaBytes = 0
+ }
+ if strings.TrimSpace(req.Name) == "" {
+ response.BadRequest(c, "Name is required")
+ return
+ }
+ if strings.TrimSpace(req.ProfileID) == "" {
+ response.BadRequest(c, "Profile ID is required")
+ return
+ }
+ if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, false); err != nil {
+ response.BadRequest(c, err.Error())
+ return
+ }
+
+ created, err := h.settingService.CreateSoraS3Profile(c.Request.Context(), &service.SoraS3Profile{
+ ProfileID: req.ProfileID,
+ Name: req.Name,
+ Enabled: req.Enabled,
+ Endpoint: req.Endpoint,
+ Region: req.Region,
+ Bucket: req.Bucket,
+ AccessKeyID: req.AccessKeyID,
+ SecretAccessKey: req.SecretAccessKey,
+ Prefix: req.Prefix,
+ ForcePathStyle: req.ForcePathStyle,
+ CDNURL: req.CDNURL,
+ DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
+ }, req.SetActive)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, toSoraS3ProfileDTO(*created))
+}
+
+// UpdateSoraS3Profile 更新 Sora S3 配置
+// PUT /api/v1/admin/settings/sora-s3/profiles/:profile_id
+func (h *SettingHandler) UpdateSoraS3Profile(c *gin.Context) {
+ profileID := strings.TrimSpace(c.Param("profile_id"))
+ if profileID == "" {
+ response.BadRequest(c, "Profile ID is required")
+ return
+ }
+
+ var req UpdateSoraS3ProfileRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if req.DefaultStorageQuotaBytes < 0 {
+ req.DefaultStorageQuotaBytes = 0
+ }
+ if strings.TrimSpace(req.Name) == "" {
+ response.BadRequest(c, "Name is required")
+ return
+ }
+
+ existingList, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ existing := findSoraS3ProfileByID(existingList.Items, profileID)
+ if existing == nil {
+ response.ErrorFrom(c, service.ErrSoraS3ProfileNotFound)
+ return
+ }
+ if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
+ response.BadRequest(c, err.Error())
+ return
+ }
+
+ updated, updateErr := h.settingService.UpdateSoraS3Profile(c.Request.Context(), profileID, &service.SoraS3Profile{
+ Name: req.Name,
+ Enabled: req.Enabled,
+ Endpoint: req.Endpoint,
+ Region: req.Region,
+ Bucket: req.Bucket,
+ AccessKeyID: req.AccessKeyID,
+ SecretAccessKey: req.SecretAccessKey,
+ Prefix: req.Prefix,
+ ForcePathStyle: req.ForcePathStyle,
+ CDNURL: req.CDNURL,
+ DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
+ })
+ if updateErr != nil {
+ response.ErrorFrom(c, updateErr)
+ return
+ }
+
+ response.Success(c, toSoraS3ProfileDTO(*updated))
+}
+
+// DeleteSoraS3Profile 删除 Sora S3 配置
+// DELETE /api/v1/admin/settings/sora-s3/profiles/:profile_id
+func (h *SettingHandler) DeleteSoraS3Profile(c *gin.Context) {
+ profileID := strings.TrimSpace(c.Param("profile_id"))
+ if profileID == "" {
+ response.BadRequest(c, "Profile ID is required")
+ return
+ }
+ if err := h.settingService.DeleteSoraS3Profile(c.Request.Context(), profileID); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"deleted": true})
+}
+
+// SetActiveSoraS3Profile 切换激活 Sora S3 配置
+// POST /api/v1/admin/settings/sora-s3/profiles/:profile_id/activate
+func (h *SettingHandler) SetActiveSoraS3Profile(c *gin.Context) {
+ profileID := strings.TrimSpace(c.Param("profile_id"))
+ if profileID == "" {
+ response.BadRequest(c, "Profile ID is required")
+ return
+ }
+ active, err := h.settingService.SetActiveSoraS3Profile(c.Request.Context(), profileID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, toSoraS3ProfileDTO(*active))
+}
+
+// UpdateSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置接口)
+// PUT /api/v1/admin/settings/sora-s3
+func (h *SettingHandler) UpdateSoraS3Settings(c *gin.Context) {
+ var req UpdateSoraS3SettingsRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ if req.DefaultStorageQuotaBytes < 0 {
+ req.DefaultStorageQuotaBytes = 0
+ }
+ if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
+ response.BadRequest(c, err.Error())
+ return
+ }
+
+ settings := &service.SoraS3Settings{
+ Enabled: req.Enabled,
+ Endpoint: req.Endpoint,
+ Region: req.Region,
+ Bucket: req.Bucket,
+ AccessKeyID: req.AccessKeyID,
+ SecretAccessKey: req.SecretAccessKey,
+ Prefix: req.Prefix,
+ ForcePathStyle: req.ForcePathStyle,
+ CDNURL: req.CDNURL,
+ DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
+ }
+ if err := h.settingService.SetSoraS3Settings(c.Request.Context(), settings); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ updatedSettings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, toSoraS3SettingsDTO(updatedSettings))
+}
+
+// TestSoraS3Connection 测试 Sora S3 连接(HeadBucket)
+// POST /api/v1/admin/settings/sora-s3/test
+func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) {
+ if h.soraS3Storage == nil {
+ response.Error(c, 500, "S3 存储服务未初始化")
+ return
+ }
+
+ var req UpdateSoraS3SettingsRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+ if !req.Enabled {
+ response.BadRequest(c, "S3 未启用,无法测试连接")
+ return
+ }
+
+ if req.SecretAccessKey == "" {
+ if req.ProfileID != "" {
+ profiles, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
+ if err == nil {
+ profile := findSoraS3ProfileByID(profiles.Items, req.ProfileID)
+ if profile != nil {
+ req.SecretAccessKey = profile.SecretAccessKey
+ }
+ }
+ }
+ if req.SecretAccessKey == "" {
+ existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
+ if err == nil {
+ req.SecretAccessKey = existing.SecretAccessKey
+ }
+ }
+ }
+
+ testCfg := &service.SoraS3Settings{
+ Enabled: true,
+ Endpoint: req.Endpoint,
+ Region: req.Region,
+ Bucket: req.Bucket,
+ AccessKeyID: req.AccessKeyID,
+ SecretAccessKey: req.SecretAccessKey,
+ Prefix: req.Prefix,
+ ForcePathStyle: req.ForcePathStyle,
+ CDNURL: req.CDNURL,
+ }
+ if err := h.soraS3Storage.TestConnectionWithSettings(c.Request.Context(), testCfg); err != nil {
+ response.Error(c, 400, "S3 连接测试失败: "+err.Error())
+ return
+ }
+ response.Success(c, gin.H{"message": "S3 连接成功"})
+}
+
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
type UpdateStreamTimeoutSettingsRequest struct {
Enabled bool `json:"enabled"`
diff --git a/backend/internal/handler/admin/usage_cleanup_handler_test.go b/backend/internal/handler/admin/usage_cleanup_handler_test.go
index ed1c7cc2..6152d5e9 100644
--- a/backend/internal/handler/admin/usage_cleanup_handler_test.go
+++ b/backend/internal/handler/admin/usage_cleanup_handler_test.go
@@ -225,6 +225,92 @@ func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) {
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
+func TestUsageHandlerCreateCleanupTaskInvalidRequestType(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
+ cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
+ router := setupCleanupRouter(cleanupService, 88)
+
+ payload := map[string]any{
+ "start_date": "2024-01-01",
+ "end_date": "2024-01-02",
+ "timezone": "UTC",
+ "request_type": "invalid",
+ }
+ body, err := json.Marshal(payload)
+ require.NoError(t, err)
+
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ recorder := httptest.NewRecorder()
+ router.ServeHTTP(recorder, req)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+}
+
+func TestUsageHandlerCreateCleanupTaskRequestTypePriority(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
+ cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
+ router := setupCleanupRouter(cleanupService, 99)
+
+ payload := map[string]any{
+ "start_date": "2024-01-01",
+ "end_date": "2024-01-02",
+ "timezone": "UTC",
+ "request_type": "ws_v2",
+ "stream": false,
+ }
+ body, err := json.Marshal(payload)
+ require.NoError(t, err)
+
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ recorder := httptest.NewRecorder()
+ router.ServeHTTP(recorder, req)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ repo.mu.Lock()
+ defer repo.mu.Unlock()
+ require.Len(t, repo.created, 1)
+ created := repo.created[0]
+ require.NotNil(t, created.Filters.RequestType)
+ require.Equal(t, int16(service.RequestTypeWSV2), *created.Filters.RequestType)
+ require.Nil(t, created.Filters.Stream)
+}
+
+func TestUsageHandlerCreateCleanupTaskWithLegacyStream(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
+ cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
+ router := setupCleanupRouter(cleanupService, 99)
+
+ payload := map[string]any{
+ "start_date": "2024-01-01",
+ "end_date": "2024-01-02",
+ "timezone": "UTC",
+ "stream": true,
+ }
+ body, err := json.Marshal(payload)
+ require.NoError(t, err)
+
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ recorder := httptest.NewRecorder()
+ router.ServeHTTP(recorder, req)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ repo.mu.Lock()
+ defer repo.mu.Unlock()
+ require.Len(t, repo.created, 1)
+ created := repo.created[0]
+ require.Nil(t, created.Filters.RequestType)
+ require.NotNil(t, created.Filters.Stream)
+ require.True(t, *created.Filters.Stream)
+}
+
func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
diff --git a/backend/internal/handler/admin/usage_handler.go b/backend/internal/handler/admin/usage_handler.go
index 5cbf18e6..d0bba773 100644
--- a/backend/internal/handler/admin/usage_handler.go
+++ b/backend/internal/handler/admin/usage_handler.go
@@ -51,6 +51,7 @@ type CreateUsageCleanupTaskRequest struct {
AccountID *int64 `json:"account_id"`
GroupID *int64 `json:"group_id"`
Model *string `json:"model"`
+ RequestType *string `json:"request_type"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
Timezone string `json:"timezone"`
@@ -101,8 +102,17 @@ func (h *UsageHandler) List(c *gin.Context) {
model := c.Query("model")
+ var requestType *int16
var stream *bool
- if streamStr := c.Query("stream"); streamStr != "" {
+ if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
+ parsed, err := service.ParseUsageRequestType(requestTypeStr)
+ if err != nil {
+ response.BadRequest(c, err.Error())
+ return
+ }
+ value := int16(parsed)
+ requestType = &value
+ } else if streamStr := c.Query("stream"); streamStr != "" {
val, err := strconv.ParseBool(streamStr)
if err != nil {
response.BadRequest(c, "Invalid stream value, use true or false")
@@ -152,6 +162,7 @@ func (h *UsageHandler) List(c *gin.Context) {
AccountID: accountID,
GroupID: groupID,
Model: model,
+ RequestType: requestType,
Stream: stream,
BillingType: billingType,
StartTime: startTime,
@@ -214,8 +225,17 @@ func (h *UsageHandler) Stats(c *gin.Context) {
model := c.Query("model")
+ var requestType *int16
var stream *bool
- if streamStr := c.Query("stream"); streamStr != "" {
+ if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
+ parsed, err := service.ParseUsageRequestType(requestTypeStr)
+ if err != nil {
+ response.BadRequest(c, err.Error())
+ return
+ }
+ value := int16(parsed)
+ requestType = &value
+ } else if streamStr := c.Query("stream"); streamStr != "" {
val, err := strconv.ParseBool(streamStr)
if err != nil {
response.BadRequest(c, "Invalid stream value, use true or false")
@@ -278,6 +298,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
AccountID: accountID,
GroupID: groupID,
Model: model,
+ RequestType: requestType,
Stream: stream,
BillingType: billingType,
StartTime: &startTime,
@@ -432,6 +453,19 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
}
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
+ var requestType *int16
+ stream := req.Stream
+ if req.RequestType != nil {
+ parsed, err := service.ParseUsageRequestType(*req.RequestType)
+ if err != nil {
+ response.BadRequest(c, err.Error())
+ return
+ }
+ value := int16(parsed)
+ requestType = &value
+ stream = nil
+ }
+
filters := service.UsageCleanupFilters{
StartTime: startTime,
EndTime: endTime,
@@ -440,7 +474,8 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
AccountID: req.AccountID,
GroupID: req.GroupID,
Model: req.Model,
- Stream: req.Stream,
+ RequestType: requestType,
+ Stream: stream,
BillingType: req.BillingType,
}
@@ -464,9 +499,13 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
if filters.Model != nil {
model = *filters.Model
}
- var stream any
+ var streamValue any
if filters.Stream != nil {
- stream = *filters.Stream
+ streamValue = *filters.Stream
+ }
+ var requestTypeName any
+ if filters.RequestType != nil {
+ requestTypeName = service.RequestTypeFromInt16(*filters.RequestType).String()
}
var billingType any
if filters.BillingType != nil {
@@ -481,7 +520,7 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
Body: req,
}
executeAdminIdempotentJSON(c, "admin.usage.cleanup_tasks.create", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
- logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
+ logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v request_type=%v stream=%v billing_type=%v tz=%q",
subject.UserID,
filters.StartTime.Format(time.RFC3339),
filters.EndTime.Format(time.RFC3339),
@@ -490,7 +529,8 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
accountID,
groupID,
model,
- stream,
+ requestTypeName,
+ streamValue,
billingType,
req.Timezone,
)
diff --git a/backend/internal/handler/admin/usage_handler_request_type_test.go b/backend/internal/handler/admin/usage_handler_request_type_test.go
new file mode 100644
index 00000000..21add574
--- /dev/null
+++ b/backend/internal/handler/admin/usage_handler_request_type_test.go
@@ -0,0 +1,117 @@
+package admin
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type adminUsageRepoCapture struct {
+ service.UsageLogRepository
+ listFilters usagestats.UsageLogFilters
+ statsFilters usagestats.UsageLogFilters
+}
+
+func (s *adminUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
+ s.listFilters = filters
+ return []service.UsageLog{}, &pagination.PaginationResult{
+ Total: 0,
+ Page: params.Page,
+ PageSize: params.PageSize,
+ Pages: 0,
+ }, nil
+}
+
+func (s *adminUsageRepoCapture) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
+ s.statsFilters = filters
+ return &usagestats.UsageStats{}, nil
+}
+
+func newAdminUsageRequestTypeTestRouter(repo *adminUsageRepoCapture) *gin.Engine {
+ gin.SetMode(gin.TestMode)
+ usageSvc := service.NewUsageService(repo, nil, nil, nil)
+ handler := NewUsageHandler(usageSvc, nil, nil, nil)
+ router := gin.New()
+ router.GET("/admin/usage", handler.List)
+ router.GET("/admin/usage/stats", handler.Stats)
+ return router
+}
+
+func TestAdminUsageListRequestTypePriority(t *testing.T) {
+ repo := &adminUsageRepoCapture{}
+ router := newAdminUsageRequestTypeTestRouter(repo)
+
+ req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=ws_v2&stream=false", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.NotNil(t, repo.listFilters.RequestType)
+ require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType)
+ require.Nil(t, repo.listFilters.Stream)
+}
+
+func TestAdminUsageListInvalidRequestType(t *testing.T) {
+ repo := &adminUsageRepoCapture{}
+ router := newAdminUsageRequestTypeTestRouter(repo)
+
+ req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=bad", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestAdminUsageListInvalidStream(t *testing.T) {
+ repo := &adminUsageRepoCapture{}
+ router := newAdminUsageRequestTypeTestRouter(repo)
+
+ req := httptest.NewRequest(http.MethodGet, "/admin/usage?stream=bad", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestAdminUsageStatsRequestTypePriority(t *testing.T) {
+ repo := &adminUsageRepoCapture{}
+ router := newAdminUsageRequestTypeTestRouter(repo)
+
+ req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=stream&stream=bad", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.NotNil(t, repo.statsFilters.RequestType)
+ require.Equal(t, int16(service.RequestTypeStream), *repo.statsFilters.RequestType)
+ require.Nil(t, repo.statsFilters.Stream)
+}
+
+func TestAdminUsageStatsInvalidRequestType(t *testing.T) {
+ repo := &adminUsageRepoCapture{}
+ router := newAdminUsageRequestTypeTestRouter(repo)
+
+ req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=oops", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestAdminUsageStatsInvalidStream(t *testing.T) {
+ repo := &adminUsageRepoCapture{}
+ router := newAdminUsageRequestTypeTestRouter(repo)
+
+ req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?stream=oops", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go
index d85202e5..f85c060e 100644
--- a/backend/internal/handler/admin/user_handler.go
+++ b/backend/internal/handler/admin/user_handler.go
@@ -34,13 +34,14 @@ func NewUserHandler(adminService service.AdminService, concurrencyService *servi
// CreateUserRequest represents admin create user request
type CreateUserRequest struct {
- Email string `json:"email" binding:"required,email"`
- Password string `json:"password" binding:"required,min=6"`
- Username string `json:"username"`
- Notes string `json:"notes"`
- Balance float64 `json:"balance"`
- Concurrency int `json:"concurrency"`
- AllowedGroups []int64 `json:"allowed_groups"`
+ Email string `json:"email" binding:"required,email"`
+ Password string `json:"password" binding:"required,min=6"`
+ Username string `json:"username"`
+ Notes string `json:"notes"`
+ Balance float64 `json:"balance"`
+ Concurrency int `json:"concurrency"`
+ AllowedGroups []int64 `json:"allowed_groups"`
+ SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
}
// UpdateUserRequest represents admin update user request
@@ -56,7 +57,8 @@ type UpdateUserRequest struct {
AllowedGroups *[]int64 `json:"allowed_groups"`
// GroupRates 用户专属分组倍率配置
// map[groupID]*rate,nil 表示删除该分组的专属倍率
- GroupRates map[int64]*float64 `json:"group_rates"`
+ GroupRates map[int64]*float64 `json:"group_rates"`
+ SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
}
// UpdateBalanceRequest represents balance update request
@@ -174,13 +176,14 @@ func (h *UserHandler) Create(c *gin.Context) {
}
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
- Email: req.Email,
- Password: req.Password,
- Username: req.Username,
- Notes: req.Notes,
- Balance: req.Balance,
- Concurrency: req.Concurrency,
- AllowedGroups: req.AllowedGroups,
+ Email: req.Email,
+ Password: req.Password,
+ Username: req.Username,
+ Notes: req.Notes,
+ Balance: req.Balance,
+ Concurrency: req.Concurrency,
+ AllowedGroups: req.AllowedGroups,
+ SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -207,15 +210,16 @@ func (h *UserHandler) Update(c *gin.Context) {
// 使用指针类型直接传递,nil 表示未提供该字段
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
- Email: req.Email,
- Password: req.Password,
- Username: req.Username,
- Notes: req.Notes,
- Balance: req.Balance,
- Concurrency: req.Concurrency,
- Status: req.Status,
- AllowedGroups: req.AllowedGroups,
- GroupRates: req.GroupRates,
+ Email: req.Email,
+ Password: req.Password,
+ Username: req.Username,
+ Notes: req.Notes,
+ Balance: req.Balance,
+ Concurrency: req.Concurrency,
+ Status: req.Status,
+ AllowedGroups: req.AllowedGroups,
+ GroupRates: req.GroupRates,
+ SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
})
if err != nil {
response.ErrorFrom(c, err)
diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go
index e0078e14..1ffa9d71 100644
--- a/backend/internal/handler/auth_handler.go
+++ b/backend/internal/handler/auth_handler.go
@@ -113,9 +113,8 @@ func (h *AuthHandler) Register(c *gin.Context) {
return
}
- // Turnstile 验证 — 始终执行,防止绕过
- // TODO: 确认前端在提交邮箱验证码注册时也传递了 turnstile_token
- if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
+ // Turnstile 验证(邮箱验证码注册场景避免重复校验一次性 token)
+ if err := h.authService.VerifyTurnstileForRegister(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c), req.VerifyCode); err != nil {
response.ErrorFrom(c, err)
return
}
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index 42ff4a84..f8298067 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -59,9 +59,11 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
return nil
}
return &AdminUser{
- User: *base,
- Notes: u.Notes,
- GroupRates: u.GroupRates,
+ User: *base,
+ Notes: u.Notes,
+ GroupRates: u.GroupRates,
+ SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
+ SoraStorageUsedBytes: u.SoraStorageUsedBytes,
}
}
@@ -152,6 +154,7 @@ func groupFromServiceBase(g *service.Group) Group {
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
+ SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
@@ -206,6 +209,13 @@ func AccountFromServiceShallow(a *service.Account) *Account {
if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 {
out.SessionIdleTimeoutMin = &idleTimeout
}
+ if rpm := a.GetBaseRPM(); rpm > 0 {
+ out.BaseRPM = &rpm
+ strategy := a.GetRPMStrategy()
+ out.RPMStrategy = &strategy
+ buffer := a.GetRPMStickyBuffer()
+ out.RPMStickyBuffer = &buffer
+ }
// TLS指纹伪装开关
if a.IsTLSFingerprintEnabled() {
enabled := true
@@ -283,7 +293,6 @@ func ProxyFromService(p *service.Proxy) *Proxy {
Host: p.Host,
Port: p.Port,
Username: p.Username,
- Password: p.Password,
Status: p.Status,
CreatedAt: p.CreatedAt,
UpdatedAt: p.UpdatedAt,
@@ -313,6 +322,51 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
}
}
+// ProxyFromServiceAdmin converts a service Proxy to AdminProxy DTO for admin users.
+// It includes the password field - user-facing endpoints must not use this.
+func ProxyFromServiceAdmin(p *service.Proxy) *AdminProxy {
+ if p == nil {
+ return nil
+ }
+ base := ProxyFromService(p)
+ if base == nil {
+ return nil
+ }
+ return &AdminProxy{
+ Proxy: *base,
+ Password: p.Password,
+ }
+}
+
+// ProxyWithAccountCountFromServiceAdmin converts a service ProxyWithAccountCount to AdminProxyWithAccountCount DTO.
+// It includes the password field - user-facing endpoints must not use this.
+func ProxyWithAccountCountFromServiceAdmin(p *service.ProxyWithAccountCount) *AdminProxyWithAccountCount {
+ if p == nil {
+ return nil
+ }
+ admin := ProxyFromServiceAdmin(&p.Proxy)
+ if admin == nil {
+ return nil
+ }
+ return &AdminProxyWithAccountCount{
+ AdminProxy: *admin,
+ AccountCount: p.AccountCount,
+ LatencyMs: p.LatencyMs,
+ LatencyStatus: p.LatencyStatus,
+ LatencyMessage: p.LatencyMessage,
+ IPAddress: p.IPAddress,
+ Country: p.Country,
+ CountryCode: p.CountryCode,
+ Region: p.Region,
+ City: p.City,
+ QualityStatus: p.QualityStatus,
+ QualityScore: p.QualityScore,
+ QualityGrade: p.QualityGrade,
+ QualitySummary: p.QualitySummary,
+ QualityChecked: p.QualityChecked,
+ }
+}
+
func ProxyAccountSummaryFromService(a *service.ProxyAccountSummary) *ProxyAccountSummary {
if a == nil {
return nil
@@ -385,6 +439,8 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
// 普通用户 DTO:严禁包含管理员字段(例如 account_rate_multiplier、ip_address、account)。
+ requestType := l.EffectiveRequestType()
+ stream, openAIWSMode := service.ApplyLegacyRequestFields(requestType, l.Stream, l.OpenAIWSMode)
return UsageLog{
ID: l.ID,
UserID: l.UserID,
@@ -409,7 +465,9 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
ActualCost: l.ActualCost,
RateMultiplier: l.RateMultiplier,
BillingType: l.BillingType,
- Stream: l.Stream,
+ RequestType: requestType.String(),
+ Stream: stream,
+ OpenAIWSMode: openAIWSMode,
DurationMs: l.DurationMs,
FirstTokenMs: l.FirstTokenMs,
ImageCount: l.ImageCount,
@@ -464,6 +522,7 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa
AccountID: task.Filters.AccountID,
GroupID: task.Filters.GroupID,
Model: task.Filters.Model,
+ RequestType: requestTypeStringPtr(task.Filters.RequestType),
Stream: task.Filters.Stream,
BillingType: task.Filters.BillingType,
},
@@ -479,6 +538,14 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa
}
}
+func requestTypeStringPtr(requestType *int16) *string {
+ if requestType == nil {
+ return nil
+ }
+ value := service.RequestTypeFromInt16(*requestType).String()
+ return &value
+}
+
func SettingFromService(s *service.Setting) *Setting {
if s == nil {
return nil
diff --git a/backend/internal/handler/dto/mappers_usage_test.go b/backend/internal/handler/dto/mappers_usage_test.go
new file mode 100644
index 00000000..d716bdc4
--- /dev/null
+++ b/backend/internal/handler/dto/mappers_usage_test.go
@@ -0,0 +1,73 @@
+package dto
+
+import (
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
+func TestUsageLogFromService_IncludesOpenAIWSMode(t *testing.T) {
+ t.Parallel()
+
+ wsLog := &service.UsageLog{
+ RequestID: "req_1",
+ Model: "gpt-5.3-codex",
+ OpenAIWSMode: true,
+ }
+ httpLog := &service.UsageLog{
+ RequestID: "resp_1",
+ Model: "gpt-5.3-codex",
+ OpenAIWSMode: false,
+ }
+
+ require.True(t, UsageLogFromService(wsLog).OpenAIWSMode)
+ require.False(t, UsageLogFromService(httpLog).OpenAIWSMode)
+ require.True(t, UsageLogFromServiceAdmin(wsLog).OpenAIWSMode)
+ require.False(t, UsageLogFromServiceAdmin(httpLog).OpenAIWSMode)
+}
+
+func TestUsageLogFromService_PrefersRequestTypeForLegacyFields(t *testing.T) {
+ t.Parallel()
+
+ log := &service.UsageLog{
+ RequestID: "req_2",
+ Model: "gpt-5.3-codex",
+ RequestType: service.RequestTypeWSV2,
+ Stream: false,
+ OpenAIWSMode: false,
+ }
+
+ userDTO := UsageLogFromService(log)
+ adminDTO := UsageLogFromServiceAdmin(log)
+
+ require.Equal(t, "ws_v2", userDTO.RequestType)
+ require.True(t, userDTO.Stream)
+ require.True(t, userDTO.OpenAIWSMode)
+ require.Equal(t, "ws_v2", adminDTO.RequestType)
+ require.True(t, adminDTO.Stream)
+ require.True(t, adminDTO.OpenAIWSMode)
+}
+
+func TestUsageCleanupTaskFromService_RequestTypeMapping(t *testing.T) {
+ t.Parallel()
+
+ requestType := int16(service.RequestTypeStream)
+ task := &service.UsageCleanupTask{
+ ID: 1,
+ Status: service.UsageCleanupStatusPending,
+ Filters: service.UsageCleanupFilters{
+ RequestType: &requestType,
+ },
+ }
+
+ dtoTask := UsageCleanupTaskFromService(task)
+ require.NotNil(t, dtoTask)
+ require.NotNil(t, dtoTask.Filters.RequestType)
+ require.Equal(t, "stream", *dtoTask.Filters.RequestType)
+}
+
+func TestRequestTypeStringPtrNil(t *testing.T) {
+ t.Parallel()
+ require.Nil(t, requestTypeStringPtr(nil))
+}
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index be94bc16..e9086010 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -37,9 +37,11 @@ type SystemSettings struct {
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
+ SoraClientEnabled bool `json:"sora_client_enabled"`
- DefaultConcurrency int `json:"default_concurrency"`
- DefaultBalance float64 `json:"default_balance"`
+ DefaultConcurrency int `json:"default_concurrency"`
+ DefaultBalance float64 `json:"default_balance"`
+ DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -57,6 +59,13 @@ type SystemSettings struct {
OpsRealtimeMonitoringEnabled bool `json:"ops_realtime_monitoring_enabled"`
OpsQueryModeDefault string `json:"ops_query_mode_default"`
OpsMetricsIntervalSeconds int `json:"ops_metrics_interval_seconds"`
+
+ MinClaudeCodeVersion string `json:"min_claude_code_version"`
+}
+
+type DefaultSubscriptionSetting struct {
+ GroupID int64 `json:"group_id"`
+ ValidityDays int `json:"validity_days"`
}
type PublicSettings struct {
@@ -79,9 +88,48 @@ type PublicSettings struct {
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
+ SoraClientEnabled bool `json:"sora_client_enabled"`
Version string `json:"version"`
}
+// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段)
+type SoraS3Settings struct {
+ Enabled bool `json:"enabled"`
+ Endpoint string `json:"endpoint"`
+ Region string `json:"region"`
+ Bucket string `json:"bucket"`
+ AccessKeyID string `json:"access_key_id"`
+ SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
+ Prefix string `json:"prefix"`
+ ForcePathStyle bool `json:"force_path_style"`
+ CDNURL string `json:"cdn_url"`
+ DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
+}
+
+// SoraS3Profile Sora S3 存储配置项 DTO(响应用,不含敏感字段)
+type SoraS3Profile struct {
+ ProfileID string `json:"profile_id"`
+ Name string `json:"name"`
+ IsActive bool `json:"is_active"`
+ Enabled bool `json:"enabled"`
+ Endpoint string `json:"endpoint"`
+ Region string `json:"region"`
+ Bucket string `json:"bucket"`
+ AccessKeyID string `json:"access_key_id"`
+ SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
+ Prefix string `json:"prefix"`
+ ForcePathStyle bool `json:"force_path_style"`
+ CDNURL string `json:"cdn_url"`
+ DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
+ UpdatedAt string `json:"updated_at"`
+}
+
+// ListSoraS3ProfilesResponse Sora S3 配置列表响应
+type ListSoraS3ProfilesResponse struct {
+ ActiveProfileID string `json:"active_profile_id"`
+ Items []SoraS3Profile `json:"items"`
+}
+
// StreamTimeoutSettings 流超时处理配置 DTO
type StreamTimeoutSettings struct {
Enabled bool `json:"enabled"`
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index 0cd1b241..b5c0640f 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -26,7 +26,9 @@ type AdminUser struct {
Notes string `json:"notes"`
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
- GroupRates map[int64]float64 `json:"group_rates,omitempty"`
+ GroupRates map[int64]float64 `json:"group_rates,omitempty"`
+ SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
+ SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes"`
}
type APIKey struct {
@@ -80,6 +82,9 @@ type Group struct {
// 无效请求兜底分组
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
+ // Sora 存储配额
+ SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
+
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
@@ -148,6 +153,12 @@ type Account struct {
MaxSessions *int `json:"max_sessions,omitempty"`
SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"`
+ // RPM 限制(仅 Anthropic OAuth/SetupToken 账号有效)
+ // 从 extra 字段提取,方便前端显示和编辑
+ BaseRPM *int `json:"base_rpm,omitempty"`
+ RPMStrategy *string `json:"rpm_strategy,omitempty"`
+ RPMStickyBuffer *int `json:"rpm_sticky_buffer,omitempty"`
+
// TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"`
@@ -210,6 +221,32 @@ type ProxyWithAccountCount struct {
QualityChecked *int64 `json:"quality_checked,omitempty"`
}
+// AdminProxy 是管理员接口使用的 proxy DTO(包含密码等敏感字段)。
+// 注意:普通接口不得使用此 DTO。
+type AdminProxy struct {
+ Proxy
+ Password string `json:"password,omitempty"`
+}
+
+// AdminProxyWithAccountCount 是管理员接口使用的带账号统计的 proxy DTO。
+type AdminProxyWithAccountCount struct {
+ AdminProxy
+ AccountCount int64 `json:"account_count"`
+ LatencyMs *int64 `json:"latency_ms,omitempty"`
+ LatencyStatus string `json:"latency_status,omitempty"`
+ LatencyMessage string `json:"latency_message,omitempty"`
+ IPAddress string `json:"ip_address,omitempty"`
+ Country string `json:"country,omitempty"`
+ CountryCode string `json:"country_code,omitempty"`
+ Region string `json:"region,omitempty"`
+ City string `json:"city,omitempty"`
+ QualityStatus string `json:"quality_status,omitempty"`
+ QualityScore *int `json:"quality_score,omitempty"`
+ QualityGrade string `json:"quality_grade,omitempty"`
+ QualitySummary string `json:"quality_summary,omitempty"`
+ QualityChecked *int64 `json:"quality_checked,omitempty"`
+}
+
type ProxyAccountSummary struct {
ID int64 `json:"id"`
Name string `json:"name"`
@@ -278,10 +315,12 @@ type UsageLog struct {
ActualCost float64 `json:"actual_cost"`
RateMultiplier float64 `json:"rate_multiplier"`
- BillingType int8 `json:"billing_type"`
- Stream bool `json:"stream"`
- DurationMs *int `json:"duration_ms"`
- FirstTokenMs *int `json:"first_token_ms"`
+ BillingType int8 `json:"billing_type"`
+ RequestType string `json:"request_type"`
+ Stream bool `json:"stream"`
+ OpenAIWSMode bool `json:"openai_ws_mode"`
+ DurationMs *int `json:"duration_ms"`
+ FirstTokenMs *int `json:"first_token_ms"`
// 图片生成字段
ImageCount int `json:"image_count"`
@@ -324,6 +363,7 @@ type UsageCleanupFilters struct {
AccountID *int64 `json:"account_id,omitempty"`
GroupID *int64 `json:"group_id,omitempty"`
Model *string `json:"model,omitempty"`
+ RequestType *string `json:"request_type,omitempty"`
Stream *bool `json:"stream,omitempty"`
BillingType *int8 `json:"billing_type,omitempty"`
}
diff --git a/backend/internal/handler/failover_loop.go b/backend/internal/handler/failover_loop.go
index 1f8a7e9a..b2583301 100644
--- a/backend/internal/handler/failover_loop.go
+++ b/backend/internal/handler/failover_loop.go
@@ -2,11 +2,12 @@ package handler
import (
"context"
- "log"
"net/http"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/service"
+ "go.uber.org/zap"
)
// TempUnscheduler 用于 HandleFailoverError 中同账号重试耗尽后的临时封禁。
@@ -78,8 +79,12 @@ func (s *FailoverState) HandleFailoverError(
// 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试
if failoverErr.RetryableOnSameAccount && s.SameAccountRetryCount[accountID] < maxSameAccountRetries {
s.SameAccountRetryCount[accountID]++
- log.Printf("Account %d: retryable error %d, same-account retry %d/%d",
- accountID, failoverErr.StatusCode, s.SameAccountRetryCount[accountID], maxSameAccountRetries)
+ logger.FromContext(ctx).Warn("gateway.failover_same_account_retry",
+ zap.Int64("account_id", accountID),
+ zap.Int("upstream_status", failoverErr.StatusCode),
+ zap.Int("same_account_retry_count", s.SameAccountRetryCount[accountID]),
+ zap.Int("same_account_retry_max", maxSameAccountRetries),
+ )
if !sleepWithContext(ctx, sameAccountRetryDelay) {
return FailoverCanceled
}
@@ -101,8 +106,12 @@ func (s *FailoverState) HandleFailoverError(
// 递增切换计数
s.SwitchCount++
- log.Printf("Account %d: upstream error %d, switching account %d/%d",
- accountID, failoverErr.StatusCode, s.SwitchCount, s.MaxSwitches)
+ logger.FromContext(ctx).Warn("gateway.failover_switch_account",
+ zap.Int64("account_id", accountID),
+ zap.Int("upstream_status", failoverErr.StatusCode),
+ zap.Int("switch_count", s.SwitchCount),
+ zap.Int("max_switches", s.MaxSwitches),
+ )
// Antigravity 平台换号线性递增延时
if platform == service.PlatformAntigravity {
@@ -127,13 +136,18 @@ func (s *FailoverState) HandleSelectionExhausted(ctx context.Context) FailoverAc
s.LastFailoverErr.StatusCode == http.StatusServiceUnavailable &&
s.SwitchCount <= s.MaxSwitches {
- log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)",
- singleAccountBackoffDelay, s.SwitchCount)
+ logger.FromContext(ctx).Warn("gateway.failover_single_account_backoff",
+ zap.Duration("backoff_delay", singleAccountBackoffDelay),
+ zap.Int("switch_count", s.SwitchCount),
+ zap.Int("max_switches", s.MaxSwitches),
+ )
if !sleepWithContext(ctx, singleAccountBackoffDelay) {
return FailoverCanceled
}
- log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d",
- s.SwitchCount, s.MaxSwitches)
+ logger.FromContext(ctx).Warn("gateway.failover_single_account_retry",
+ zap.Int("switch_count", s.SwitchCount),
+ zap.Int("max_switches", s.MaxSwitches),
+ )
s.FailedAccountIDs = make(map[int64]struct{})
return FailoverContinue
}
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index fe40e9d2..2bd59f32 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -6,9 +6,10 @@ import (
"encoding/json"
"errors"
"fmt"
- "io"
"net/http"
+ "strconv"
"strings"
+ "sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -17,6 +18,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
@@ -27,6 +29,10 @@ import (
"go.uber.org/zap"
)
+const gatewayCompatibilityMetricsLogInterval = 1024
+
+var gatewayCompatibilityMetricsLogCounter atomic.Uint64
+
// GatewayHandler handles API gateway requests
type GatewayHandler struct {
gatewayService *service.GatewayService
@@ -42,6 +48,7 @@ type GatewayHandler struct {
maxAccountSwitches int
maxAccountSwitchesGemini int
cfg *config.Config
+ settingService *service.SettingService
}
// NewGatewayHandler creates a new GatewayHandler
@@ -57,6 +64,7 @@ func NewGatewayHandler(
usageRecordWorkerPool *service.UsageRecordWorkerPool,
errorPassthroughService *service.ErrorPassthroughService,
cfg *config.Config,
+ settingService *service.SettingService,
) *GatewayHandler {
pingInterval := time.Duration(0)
maxAccountSwitches := 10
@@ -84,6 +92,7 @@ func NewGatewayHandler(
maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
cfg: cfg,
+ settingService: settingService,
}
}
@@ -109,9 +118,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
+ defer h.maybeLogCompatibilityFallbackMetrics(reqLog)
// 读取请求体
- body, err := io.ReadAll(c.Request.Body)
+ body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
@@ -140,16 +150,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
- ctx := context.WithValue(c.Request.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true)
+ ctx := service.WithIsMaxTokensOneHaikuRequest(c.Request.Context(), true, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
}
- // 检查是否为 Claude Code 客户端,设置到 context 中
- SetClaudeCodeClientContext(c, body)
+ // 检查是否为 Claude Code 客户端,设置到 context 中(复用已解析请求,避免二次反序列化)。
+ SetClaudeCodeClientContext(c, body, parsedReq)
isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context())
+ // 版本检查:仅对 Claude Code 客户端,拒绝低于最低版本的请求
+ if !h.checkClaudeCodeVersion(c) {
+ return
+ }
+
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
- c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
+ c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
setOpsRequestContext(c, reqModel, reqStream, body)
@@ -247,8 +262,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if apiKey.GroupID != nil {
prefetchedGroupID = *apiKey.GroupID
}
- ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
- ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
+ ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
}
}
@@ -261,7 +275,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
- ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
+ ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
}
@@ -275,7 +289,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
- ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
+ ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
continue
case FailoverCanceled:
@@ -364,7 +378,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var result *service.ForwardResult
requestCtx := c.Request.Context()
if fs.SwitchCount > 0 {
- requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
+ requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
}
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
@@ -397,6 +411,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
+ // RPM 计数递增(Forward 成功后)
+ // 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。
+ // 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。
+ if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 {
+ if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil {
+ reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
+ }
+ }
+
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
@@ -439,7 +462,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), currentAPIKey.GroupID) {
- ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
+ ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
}
@@ -458,7 +481,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
- ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
+ ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
continue
case FailoverCanceled:
@@ -547,7 +570,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var result *service.ForwardResult
requestCtx := c.Request.Context()
if fs.SwitchCount > 0 {
- requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
+ requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
}
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
@@ -589,7 +612,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
- // 兜底重试按“直接请求兜底分组”处理:清除强制平台,允许按分组平台调度
+ // 兜底重试按"直接请求兜底分组"处理:清除强制平台,允许按分组平台调度
ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, "")
c.Request = c.Request.WithContext(ctx)
currentAPIKey = fallbackAPIKey
@@ -623,6 +646,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
+ // RPM 计数递增(Forward 成功后)
+ // 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。
+ // 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。
+ if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 {
+ if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil {
+ reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
+ }
+ }
+
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
@@ -956,20 +988,8 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
- // Send error event in SSE format with proper JSON marshaling
- errorData := map[string]any{
- "type": "error",
- "error": map[string]string{
- "type": errType,
- "message": message,
- },
- }
- jsonBytes, err := json.Marshal(errorData)
- if err != nil {
- _ = c.Error(err)
- return
- }
- errorEvent := fmt.Sprintf("data: %s\n\n", string(jsonBytes))
+ // SSE 错误事件固定 schema,使用 Quote 直拼可避免额外 Marshal 分配。
+ errorEvent := `data: {"type":"error","error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n"
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
@@ -991,6 +1011,41 @@ func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarte
return true
}
+// checkClaudeCodeVersion 检查 Claude Code 客户端版本是否满足最低要求
+// 仅对已识别的 Claude Code 客户端执行,count_tokens 路径除外
+func (h *GatewayHandler) checkClaudeCodeVersion(c *gin.Context) bool {
+ ctx := c.Request.Context()
+ if !service.IsClaudeCodeClient(ctx) {
+ return true
+ }
+
+ // 排除 count_tokens 子路径
+ if strings.HasSuffix(c.Request.URL.Path, "/count_tokens") {
+ return true
+ }
+
+ minVersion := h.settingService.GetMinClaudeCodeVersion(ctx)
+ if minVersion == "" {
+ return true // 未设置,不检查
+ }
+
+ clientVersion := service.GetClaudeCodeVersion(ctx)
+ if clientVersion == "" {
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error",
+ "Unable to determine Claude Code version. Please update Claude Code: npm update -g @anthropic-ai/claude-code")
+ return false
+ }
+
+ if service.CompareVersions(clientVersion, minVersion) < 0 {
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error",
+ fmt.Sprintf("Your Claude Code version (%s) is below the minimum required version (%s). Please update: npm update -g @anthropic-ai/claude-code",
+ clientVersion, minVersion))
+ return false
+ }
+
+ return true
+}
+
// errorResponse 返回Claude API格式的错误响应
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
@@ -1024,9 +1079,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
+ defer h.maybeLogCompatibilityFallbackMetrics(reqLog)
// 读取请求体
- body, err := io.ReadAll(c.Request.Body)
+ body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
@@ -1041,9 +1097,6 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
- // 检查是否为 Claude Code 客户端,设置到 context 中
- SetClaudeCodeClientContext(c, body)
-
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
@@ -1051,9 +1104,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
+ // count_tokens 走 messages 严格校验时,复用已解析请求,避免二次反序列化。
+ SetClaudeCodeClientContext(c, body, parsedReq)
reqLog = reqLog.With(zap.String("model", parsedReq.Model), zap.Bool("stream", parsedReq.Stream))
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
- c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
+ c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
// 验证 model 必填
if parsedReq.Model == "" {
@@ -1217,24 +1272,8 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce
textDeltas = []string{"New", " Conversation"}
}
- // Build message_start event with proper JSON marshaling
- messageStart := map[string]any{
- "type": "message_start",
- "message": map[string]any{
- "id": msgID,
- "type": "message",
- "role": "assistant",
- "model": model,
- "content": []any{},
- "stop_reason": nil,
- "stop_sequence": nil,
- "usage": map[string]int{
- "input_tokens": 10,
- "output_tokens": 0,
- },
- },
- }
- messageStartJSON, _ := json.Marshal(messageStart)
+ // Build message_start event with fixed schema.
+ messageStartJSON := `{"type":"message_start","message":{"id":` + strconv.Quote(msgID) + `,"type":"message","role":"assistant","model":` + strconv.Quote(model) + `,"content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":0}}}`
// Build events
events := []string{
@@ -1244,31 +1283,12 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce
// Add text deltas
for _, text := range textDeltas {
- delta := map[string]any{
- "type": "content_block_delta",
- "index": 0,
- "delta": map[string]string{
- "type": "text_delta",
- "text": text,
- },
- }
- deltaJSON, _ := json.Marshal(delta)
+ deltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":` + strconv.Quote(text) + `}}`
events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON))
}
// Add final events
- messageDelta := map[string]any{
- "type": "message_delta",
- "delta": map[string]any{
- "stop_reason": "end_turn",
- "stop_sequence": nil,
- },
- "usage": map[string]int{
- "input_tokens": 10,
- "output_tokens": outputTokens,
- },
- }
- messageDeltaJSON, _ := json.Marshal(messageDelta)
+ messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":10,"output_tokens":` + strconv.Itoa(outputTokens) + `}}`
events = append(events,
`event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`,
@@ -1366,6 +1386,30 @@ func billingErrorDetails(err error) (status int, code, message string) {
return http.StatusForbidden, "billing_error", msg
}
+func (h *GatewayHandler) metadataBridgeEnabled() bool {
+ if h == nil || h.cfg == nil {
+ return true
+ }
+ return h.cfg.Gateway.OpenAIWS.MetadataBridgeEnabled
+}
+
+func (h *GatewayHandler) maybeLogCompatibilityFallbackMetrics(reqLog *zap.Logger) {
+ if reqLog == nil {
+ return
+ }
+ if gatewayCompatibilityMetricsLogCounter.Add(1)%gatewayCompatibilityMetricsLogInterval != 0 {
+ return
+ }
+ metrics := service.SnapshotOpenAICompatibilityFallbackMetrics()
+ reqLog.Info("gateway.compatibility_fallback_metrics",
+ zap.Int64("session_hash_legacy_read_fallback_total", metrics.SessionHashLegacyReadFallbackTotal),
+ zap.Int64("session_hash_legacy_read_fallback_hit", metrics.SessionHashLegacyReadFallbackHit),
+ zap.Int64("session_hash_legacy_dual_write_total", metrics.SessionHashLegacyDualWriteTotal),
+ zap.Float64("session_hash_legacy_read_hit_rate", metrics.SessionHashLegacyReadHitRate),
+ zap.Int64("metadata_legacy_fallback_total", metrics.MetadataLegacyFallbackTotal),
+ )
+}
+
func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
if task == nil {
return
@@ -1377,5 +1421,13 @@ func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
+ defer func() {
+ if recovered := recover(); recovered != nil {
+ logger.L().With(
+ zap.String("component", "handler.gateway.messages"),
+ zap.Any("panic", recovered),
+ ).Error("gateway.usage_record_task_panic_recovered")
+ }
+ }()
task(ctx)
}
diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
index 15d85949..2afa6440 100644
--- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
+++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
@@ -119,6 +119,13 @@ func (f *fakeConcurrencyCache) GetAccountsLoadBatch(context.Context, []service.A
func (f *fakeConcurrencyCache) GetUsersLoadBatch(context.Context, []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
return map[int64]*service.UserLoadInfo{}, nil
}
+func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) {
+ result := make(map[int64]int, len(accountIDs))
+ for _, id := range accountIDs {
+ result[id] = 0
+ }
+ return result, nil
+}
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
@@ -146,6 +153,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // deferredService
nil, // claudeTokenProvider
nil, // sessionLimitCache
+ nil, // rpmCache
nil, // digestStore
)
diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go
index efff7997..09e6c09b 100644
--- a/backend/internal/handler/gateway_helper.go
+++ b/backend/internal/handler/gateway_helper.go
@@ -18,14 +18,21 @@ import (
// claudeCodeValidator is a singleton validator for Claude Code client detection
var claudeCodeValidator = service.NewClaudeCodeValidator()
+const claudeCodeParsedRequestContextKey = "claude_code_parsed_request"
+
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
// 返回更新后的 context
-func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
+func SetClaudeCodeClientContext(c *gin.Context, body []byte, parsedReq *service.ParsedRequest) {
if c == nil || c.Request == nil {
return
}
+ if parsedReq != nil {
+ c.Set(claudeCodeParsedRequestContextKey, parsedReq)
+ }
+
+ ua := c.GetHeader("User-Agent")
// Fast path:非 Claude CLI UA 直接判定 false,避免热路径二次 JSON 反序列化。
- if !claudeCodeValidator.ValidateUserAgent(c.GetHeader("User-Agent")) {
+ if !claudeCodeValidator.ValidateUserAgent(ua) {
ctx := service.SetClaudeCodeClient(c.Request.Context(), false)
c.Request = c.Request.WithContext(ctx)
return
@@ -37,8 +44,11 @@ func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
isClaudeCode = true
} else {
// 仅在确认为 Claude CLI 且 messages 路径时再做 body 解析。
- var bodyMap map[string]any
- if len(body) > 0 {
+ bodyMap := claudeCodeBodyMapFromParsedRequest(parsedReq)
+ if bodyMap == nil {
+ bodyMap = claudeCodeBodyMapFromContextCache(c)
+ }
+ if bodyMap == nil && len(body) > 0 {
_ = json.Unmarshal(body, &bodyMap)
}
isClaudeCode = claudeCodeValidator.Validate(c.Request, bodyMap)
@@ -46,9 +56,53 @@ func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
// 更新 request context
ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode)
+
+ // 仅在确认为 Claude Code 客户端时提取版本号写入 context
+ if isClaudeCode {
+ if version := claudeCodeValidator.ExtractVersion(ua); version != "" {
+ ctx = service.SetClaudeCodeVersion(ctx, version)
+ }
+ }
+
c.Request = c.Request.WithContext(ctx)
}
+func claudeCodeBodyMapFromParsedRequest(parsedReq *service.ParsedRequest) map[string]any {
+ if parsedReq == nil {
+ return nil
+ }
+ bodyMap := map[string]any{
+ "model": parsedReq.Model,
+ }
+ if parsedReq.System != nil || parsedReq.HasSystem {
+ bodyMap["system"] = parsedReq.System
+ }
+ if parsedReq.MetadataUserID != "" {
+ bodyMap["metadata"] = map[string]any{"user_id": parsedReq.MetadataUserID}
+ }
+ return bodyMap
+}
+
+func claudeCodeBodyMapFromContextCache(c *gin.Context) map[string]any {
+ if c == nil {
+ return nil
+ }
+ if cached, ok := c.Get(service.OpenAIParsedRequestBodyKey); ok {
+ if bodyMap, ok := cached.(map[string]any); ok {
+ return bodyMap
+ }
+ }
+ if cached, ok := c.Get(claudeCodeParsedRequestContextKey); ok {
+ switch v := cached.(type) {
+ case *service.ParsedRequest:
+ return claudeCodeBodyMapFromParsedRequest(v)
+ case service.ParsedRequest:
+ return claudeCodeBodyMapFromParsedRequest(&v)
+ }
+ }
+ return nil
+}
+
// 并发槽位等待相关常量
//
// 性能优化说明:
diff --git a/backend/internal/handler/gateway_helper_fastpath_test.go b/backend/internal/handler/gateway_helper_fastpath_test.go
index 3e6c376b..31d489f0 100644
--- a/backend/internal/handler/gateway_helper_fastpath_test.go
+++ b/backend/internal/handler/gateway_helper_fastpath_test.go
@@ -33,6 +33,14 @@ func (m *concurrencyCacheMock) GetAccountConcurrency(ctx context.Context, accoun
return 0, nil
}
+func (m *concurrencyCacheMock) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
+ result := make(map[int64]int, len(accountIDs))
+ for _, accountID := range accountIDs {
+ result[accountID] = 0
+ }
+ return result, nil
+}
+
func (m *concurrencyCacheMock) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return true, nil
}
diff --git a/backend/internal/handler/gateway_helper_hotpath_test.go b/backend/internal/handler/gateway_helper_hotpath_test.go
index 3fdf1bfc..f8f7eaca 100644
--- a/backend/internal/handler/gateway_helper_hotpath_test.go
+++ b/backend/internal/handler/gateway_helper_hotpath_test.go
@@ -49,6 +49,14 @@ func (s *helperConcurrencyCacheStub) GetAccountConcurrency(ctx context.Context,
return 0, nil
}
+func (s *helperConcurrencyCacheStub) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
+ out := make(map[int64]int, len(accountIDs))
+ for _, accountID := range accountIDs {
+ out[accountID] = 0
+ }
+ return out, nil
+}
+
func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return true, nil
}
@@ -133,7 +141,7 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
c.Request.Header.Set("User-Agent", "curl/8.6.0")
- SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
+ SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil)
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
})
@@ -141,7 +149,7 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
c, _ := newHelperTestContext(http.MethodGet, "/v1/models")
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
- SetClaudeCodeClientContext(c, nil)
+ SetClaudeCodeClientContext(c, nil, nil)
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
})
@@ -152,7 +160,7 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
c.Request.Header.Set("anthropic-version", "2023-06-01")
- SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
+ SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil)
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
})
@@ -160,11 +168,51 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
// 缺少严格校验所需 header + body 字段
- SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`))
+ SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`), nil)
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
})
}
+func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing.T) {
+ t.Run("reuse parsed request without body unmarshal", func(t *testing.T) {
+ c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
+ c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
+ c.Request.Header.Set("X-App", "claude-code")
+ c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
+ c.Request.Header.Set("anthropic-version", "2023-06-01")
+
+ parsedReq := &service.ParsedRequest{
+ Model: "claude-3-5-sonnet-20241022",
+ System: []any{
+ map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
+ },
+ MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123",
+ }
+
+ // body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。
+ SetClaudeCodeClientContext(c, []byte(`{invalid`), parsedReq)
+ require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
+ })
+
+ t.Run("reuse context cache without body unmarshal", func(t *testing.T) {
+ c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
+ c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
+ c.Request.Header.Set("X-App", "claude-code")
+ c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
+ c.Request.Header.Set("anthropic-version", "2023-06-01")
+ c.Set(service.OpenAIParsedRequestBodyKey, map[string]any{
+ "model": "claude-3-5-sonnet-20241022",
+ "system": []any{
+ map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."},
+ },
+ "metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"},
+ })
+
+ SetClaudeCodeClientContext(c, []byte(`{invalid`), nil)
+ require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
+ })
+}
+
func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
cache := &helperConcurrencyCacheStub{
accountSeq: []bool{false, true},
diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go
index 2da0570b..50af9c8f 100644
--- a/backend/internal/handler/gemini_v1beta_handler.go
+++ b/backend/internal/handler/gemini_v1beta_handler.go
@@ -7,16 +7,15 @@ import (
"encoding/hex"
"encoding/json"
"errors"
- "io"
"net/http"
"regexp"
"strings"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
- "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
+ pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -168,7 +167,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
stream := action == "streamGenerateContent"
reqLog = reqLog.With(zap.String("model", modelName), zap.String("action", action), zap.Bool("stream", stream))
- body, err := io.ReadAll(c.Request.Body)
+ body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit))
@@ -268,8 +267,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if apiKey.GroupID != nil {
prefetchedGroupID = *apiKey.GroupID
}
- ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
- ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
+ ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
}
}
@@ -349,7 +347,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
- ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
+ ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
}
@@ -363,7 +361,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
- ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
+ ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
continue
case FailoverCanceled:
@@ -456,7 +454,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
var result *service.ForwardResult
requestCtx := c.Request.Context()
if fs.SwitchCount > 0 {
- requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
+ requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
}
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession)
diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go
index b999180b..1e1247fc 100644
--- a/backend/internal/handler/handler.go
+++ b/backend/internal/handler/handler.go
@@ -11,6 +11,7 @@ type AdminHandlers struct {
Group *admin.GroupHandler
Account *admin.AccountHandler
Announcement *admin.AnnouncementHandler
+ DataManagement *admin.DataManagementHandler
OAuth *admin.OAuthHandler
OpenAIOAuth *admin.OpenAIOAuthHandler
GeminiOAuth *admin.GeminiOAuthHandler
@@ -25,6 +26,7 @@ type AdminHandlers struct {
Usage *admin.UsageHandler
UserAttribute *admin.UserAttributeHandler
ErrorPassthrough *admin.ErrorPassthroughHandler
+ APIKey *admin.AdminAPIKeyHandler
}
// Handlers contains all HTTP handlers
@@ -40,6 +42,7 @@ type Handlers struct {
Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler
SoraGateway *SoraGatewayHandler
+ SoraClient *SoraClientHandler
Setting *SettingHandler
Totp *TotpHandler
}
diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go
index 50af684d..4bbd17ba 100644
--- a/backend/internal/handler/openai_gateway_handler.go
+++ b/backend/internal/handler/openai_gateway_handler.go
@@ -5,17 +5,20 @@ import (
"encoding/json"
"errors"
"fmt"
- "io"
"net/http"
+ "runtime/debug"
+ "strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
+ pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
+ coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"go.uber.org/zap"
@@ -64,6 +67,11 @@ func NewOpenAIGatewayHandler(
// Responses handles OpenAI Responses API endpoint
// POST /openai/v1/responses
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
+ // 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。
+ streamStarted := false
+ defer h.recoverResponsesPanic(c, &streamStarted)
+ setOpenAIClientTransportHTTP(c)
+
requestStart := time.Now()
// Get apiKey and user from context (set by ApiKeyAuth middleware)
@@ -85,9 +93,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
+ if !h.ensureResponsesDependencies(c, reqLog) {
+ return
+ }
// Read request body
- body, err := io.ReadAll(c.Request.Body)
+ body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
@@ -125,43 +136,30 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
reqStream := streamResult.Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
+ previousResponseID := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String())
+ if previousResponseID != "" {
+ previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
+ reqLog = reqLog.With(
+ zap.Bool("has_previous_response_id", true),
+ zap.String("previous_response_id_kind", previousResponseIDKind),
+ zap.Int("previous_response_id_len", len(previousResponseID)),
+ )
+ if previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID {
+ reqLog.Warn("openai.request_validation_failed",
+ zap.String("reason", "previous_response_id_looks_like_message_id"),
+ )
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id")
+ return
+ }
+ }
setOpsRequestContext(c, reqModel, reqStream, body)
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
- // 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call,
- // 或带 id 且与 call_id 匹配的 item_reference。
- // 此路径需要遍历 input 数组做 call_id 关联检查,保留 Unmarshal
- if gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
- var reqBody map[string]any
- if err := json.Unmarshal(body, &reqBody); err == nil {
- c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
- if service.HasFunctionCallOutput(reqBody) {
- previousResponseID, _ := reqBody["previous_response_id"].(string)
- if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
- if service.HasFunctionCallOutputMissingCallID(reqBody) {
- reqLog.Warn("openai.request_validation_failed",
- zap.String("reason", "function_call_output_missing_call_id"),
- )
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
- return
- }
- callIDs := service.FunctionCallOutputCallIDs(reqBody)
- if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
- reqLog.Warn("openai.request_validation_failed",
- zap.String("reason", "function_call_output_missing_item_reference"),
- )
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
- return
- }
- }
- }
- }
+ if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
+ return
}
- // Track if we've started streaming (for error handling)
- streamStarted := false
-
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
@@ -173,51 +171,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
routingStart := time.Now()
- // 0. 先尝试直接抢占用户槽位(快速路径)
- userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(c.Request.Context(), subject.UserID, subject.Concurrency)
- if err != nil {
- reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
- h.handleConcurrencyError(c, err, "user", streamStarted)
+ userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
+ if !acquired {
return
}
-
- waitCounted := false
- if !userAcquired {
- // 仅在抢槽失败时才进入等待队列,减少常态请求 Redis 写入。
- maxWait := service.CalculateMaxWait(subject.Concurrency)
- canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
- if waitErr != nil {
- reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
- // 按现有降级语义:等待计数异常时放行后续抢槽流程
- } else if !canWait {
- reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
- h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
- return
- }
- if waitErr == nil && canWait {
- waitCounted = true
- }
- defer func() {
- if waitCounted {
- h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
- }
- }()
-
- userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
- if err != nil {
- reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
- h.handleConcurrencyError(c, err, "user", streamStarted)
- return
- }
- }
-
- // 用户槽位已获取:退出等待队列计数。
- if waitCounted {
- h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
- waitCounted = false
- }
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
- userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
@@ -241,7 +199,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
for {
// Select account supporting the requested model
reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
- selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
+ selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
+ c.Request.Context(),
+ apiKey.GroupID,
+ previousResponseID,
+ sessionHash,
+ reqModel,
+ failedAccountIDs,
+ service.OpenAIUpstreamTransportAny,
+ )
if err != nil {
reqLog.Warn("openai.account_select_failed",
zap.Error(err),
@@ -258,80 +224,30 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
return
}
+ if selection == nil || selection.Account == nil {
+ h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
+ return
+ }
+ if previousResponseID != "" && selection != nil && selection.Account != nil {
+ reqLog.Debug("openai.account_selected_with_previous_response_id", zap.Int64("account_id", selection.Account.ID))
+ }
+ reqLog.Debug("openai.account_schedule_decision",
+ zap.String("layer", scheduleDecision.Layer),
+ zap.Bool("sticky_previous_hit", scheduleDecision.StickyPreviousHit),
+ zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit),
+ zap.Int("candidate_count", scheduleDecision.CandidateCount),
+ zap.Int("top_k", scheduleDecision.TopK),
+ zap.Int64("latency_ms", scheduleDecision.LatencyMs),
+ zap.Float64("load_skew", scheduleDecision.LoadSkew),
+ )
account := selection.Account
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
setOpsSelectedAccount(c, account.ID, account.Platform)
- // 3. Acquire account concurrency slot
- accountReleaseFunc := selection.ReleaseFunc
- if !selection.Acquired {
- if selection.WaitPlan == nil {
- h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
- return
- }
-
- // 先快速尝试一次账号槽位,命中则跳过等待计数写入。
- fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
- c.Request.Context(),
- account.ID,
- selection.WaitPlan.MaxConcurrency,
- )
- if err != nil {
- reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
- h.handleConcurrencyError(c, err, "account", streamStarted)
- return
- }
- if fastAcquired {
- accountReleaseFunc = fastReleaseFunc
- if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
- reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
- }
- } else {
- accountWaitCounted := false
- canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
- if err != nil {
- reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
- } else if !canWait {
- reqLog.Info("openai.account_wait_queue_full",
- zap.Int64("account_id", account.ID),
- zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
- )
- h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
- return
- }
- if err == nil && canWait {
- accountWaitCounted = true
- }
- releaseWait := func() {
- if accountWaitCounted {
- h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
- accountWaitCounted = false
- }
- }
-
- accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
- c,
- account.ID,
- selection.WaitPlan.MaxConcurrency,
- selection.WaitPlan.Timeout,
- reqStream,
- &streamStarted,
- )
- if err != nil {
- reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
- releaseWait()
- h.handleConcurrencyError(c, err, "account", streamStarted)
- return
- }
- // Slot acquired: no longer waiting in queue.
- releaseWait()
- if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
- reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
- }
- }
+ accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
+ if !acquired {
+ return
}
- // 账号槽位/等待计数需要在超时或断开时安全回收
- accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
// Forward request
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
@@ -353,6 +269,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
+ h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
@@ -368,14 +286,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
)
continue
}
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
- reqLog.Error("openai.forward_failed",
+ fields := []zap.Field{
zap.Int64("account_id", account.ID),
zap.Bool("fallback_error_response_written", wroteFallback),
zap.Error(err),
- )
+ }
+ if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
+ reqLog.Warn("openai.forward_failed", fields...)
+ return
+ }
+ reqLog.Error("openai.forward_failed", fields...)
return
}
+ if result != nil {
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
+ } else {
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
+ }
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent := c.GetHeader("User-Agent")
@@ -411,6 +340,525 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
}
+func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
+ if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
+ return true
+ }
+
+ var reqBody map[string]any
+ if err := json.Unmarshal(body, &reqBody); err != nil {
+ // 保持原有容错语义:解析失败时跳过预校验,沿用后续上游校验结果。
+ return true
+ }
+
+ c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
+ validation := service.ValidateFunctionCallOutputContext(reqBody)
+ if !validation.HasFunctionCallOutput {
+ return true
+ }
+
+ previousResponseID, _ := reqBody["previous_response_id"].(string)
+ if strings.TrimSpace(previousResponseID) != "" || validation.HasToolCallContext {
+ return true
+ }
+
+ if validation.HasFunctionCallOutputMissingCallID {
+ reqLog.Warn("openai.request_validation_failed",
+ zap.String("reason", "function_call_output_missing_call_id"),
+ )
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
+ return false
+ }
+ if validation.HasItemReferenceForAllCallIDs {
+ return true
+ }
+
+ reqLog.Warn("openai.request_validation_failed",
+ zap.String("reason", "function_call_output_missing_item_reference"),
+ )
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
+ return false
+}
+
+func (h *OpenAIGatewayHandler) acquireResponsesUserSlot(
+ c *gin.Context,
+ userID int64,
+ userConcurrency int,
+ reqStream bool,
+ streamStarted *bool,
+ reqLog *zap.Logger,
+) (func(), bool) {
+ ctx := c.Request.Context()
+ userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, userID, userConcurrency)
+ if err != nil {
+ reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
+ h.handleConcurrencyError(c, err, "user", *streamStarted)
+ return nil, false
+ }
+ if userAcquired {
+ return wrapReleaseOnDone(ctx, userReleaseFunc), true
+ }
+
+ maxWait := service.CalculateMaxWait(userConcurrency)
+ canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(ctx, userID, maxWait)
+ if waitErr != nil {
+ reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
+ // 按现有降级语义:等待计数异常时放行后续抢槽流程
+ } else if !canWait {
+ reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
+ h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
+ return nil, false
+ }
+
+ waitCounted := waitErr == nil && canWait
+ defer func() {
+ if waitCounted {
+ h.concurrencyHelper.DecrementWaitCount(ctx, userID)
+ }
+ }()
+
+ userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, userID, userConcurrency, reqStream, streamStarted)
+ if err != nil {
+ reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
+ h.handleConcurrencyError(c, err, "user", *streamStarted)
+ return nil, false
+ }
+
+ // 槽位获取成功后,立刻退出等待计数。
+ if waitCounted {
+ h.concurrencyHelper.DecrementWaitCount(ctx, userID)
+ waitCounted = false
+ }
+ return wrapReleaseOnDone(ctx, userReleaseFunc), true
+}
+
+func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot(
+ c *gin.Context,
+ groupID *int64,
+ sessionHash string,
+ selection *service.AccountSelectionResult,
+ reqStream bool,
+ streamStarted *bool,
+ reqLog *zap.Logger,
+) (func(), bool) {
+ if selection == nil || selection.Account == nil {
+ h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
+ return nil, false
+ }
+
+ ctx := c.Request.Context()
+ account := selection.Account
+ if selection.Acquired {
+ return wrapReleaseOnDone(ctx, selection.ReleaseFunc), true
+ }
+ if selection.WaitPlan == nil {
+ h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
+ return nil, false
+ }
+
+ fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
+ ctx,
+ account.ID,
+ selection.WaitPlan.MaxConcurrency,
+ )
+ if err != nil {
+ reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
+ h.handleConcurrencyError(c, err, "account", *streamStarted)
+ return nil, false
+ }
+ if fastAcquired {
+ if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil {
+ reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
+ }
+ return wrapReleaseOnDone(ctx, fastReleaseFunc), true
+ }
+
+ canWait, waitErr := h.concurrencyHelper.IncrementAccountWaitCount(ctx, account.ID, selection.WaitPlan.MaxWaiting)
+ if waitErr != nil {
+ reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(waitErr))
+ } else if !canWait {
+ reqLog.Info("openai.account_wait_queue_full",
+ zap.Int64("account_id", account.ID),
+ zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
+ )
+ h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", *streamStarted)
+ return nil, false
+ }
+
+ accountWaitCounted := waitErr == nil && canWait
+ releaseWait := func() {
+ if accountWaitCounted {
+ h.concurrencyHelper.DecrementAccountWaitCount(ctx, account.ID)
+ accountWaitCounted = false
+ }
+ }
+ defer releaseWait()
+
+ accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
+ c,
+ account.ID,
+ selection.WaitPlan.MaxConcurrency,
+ selection.WaitPlan.Timeout,
+ reqStream,
+ streamStarted,
+ )
+ if err != nil {
+ reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
+ h.handleConcurrencyError(c, err, "account", *streamStarted)
+ return nil, false
+ }
+
+ // Slot acquired: no longer waiting in queue.
+ releaseWait()
+ if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil {
+ reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
+ }
+ return wrapReleaseOnDone(ctx, accountReleaseFunc), true
+}
+
+// ResponsesWebSocket handles OpenAI Responses API WebSocket ingress endpoint
+// GET /openai/v1/responses (Upgrade: websocket)
+func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
+ if !isOpenAIWSUpgradeRequest(c.Request) {
+ h.errorResponse(c, http.StatusUpgradeRequired, "invalid_request_error", "WebSocket upgrade required (Upgrade: websocket)")
+ return
+ }
+ setOpenAIClientTransportWS(c)
+
+ apiKey, ok := middleware2.GetAPIKeyFromContext(c)
+ if !ok {
+ h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
+ return
+ }
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
+ return
+ }
+
+ reqLog := requestLogger(
+ c,
+ "handler.openai_gateway.responses_ws",
+ zap.Int64("user_id", subject.UserID),
+ zap.Int64("api_key_id", apiKey.ID),
+ zap.Any("group_id", apiKey.GroupID),
+ zap.Bool("openai_ws_mode", true),
+ )
+ if !h.ensureResponsesDependencies(c, reqLog) {
+ return
+ }
+ reqLog.Info("openai.websocket_ingress_started")
+ clientIP := ip.GetClientIP(c)
+ userAgent := strings.TrimSpace(c.GetHeader("User-Agent"))
+
+ wsConn, err := coderws.Accept(c.Writer, c.Request, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ reqLog.Warn("openai.websocket_accept_failed",
+ zap.Error(err),
+ zap.String("client_ip", clientIP),
+ zap.String("request_user_agent", userAgent),
+ zap.String("upgrade_header", strings.TrimSpace(c.GetHeader("Upgrade"))),
+ zap.String("connection_header", strings.TrimSpace(c.GetHeader("Connection"))),
+ zap.String("sec_websocket_version", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Version"))),
+ zap.Bool("has_sec_websocket_key", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Key")) != ""),
+ )
+ return
+ }
+ defer func() {
+ _ = wsConn.CloseNow()
+ }()
+ wsConn.SetReadLimit(16 * 1024 * 1024)
+
+ ctx := c.Request.Context()
+ readCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
+ msgType, firstMessage, err := wsConn.Read(readCtx)
+ cancel()
+ if err != nil {
+ closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
+ reqLog.Warn("openai.websocket_read_first_message_failed",
+ zap.Error(err),
+ zap.String("client_ip", clientIP),
+ zap.String("close_status", closeStatus),
+ zap.String("close_reason", closeReason),
+ zap.Duration("read_timeout", 30*time.Second),
+ )
+ closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "missing first response.create message")
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "unsupported websocket message type")
+ return
+ }
+ if !gjson.ValidBytes(firstMessage) {
+ closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "invalid JSON payload")
+ return
+ }
+
+ reqModel := strings.TrimSpace(gjson.GetBytes(firstMessage, "model").String())
+ if reqModel == "" {
+ closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model is required in first response.create payload")
+ return
+ }
+ previousResponseID := strings.TrimSpace(gjson.GetBytes(firstMessage, "previous_response_id").String())
+ previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
+ if previousResponseID != "" && previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID {
+ closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "previous_response_id must be a response.id (resp_*), not a message id")
+ return
+ }
+ reqLog = reqLog.With(
+ zap.Bool("ws_ingress", true),
+ zap.String("model", reqModel),
+ zap.Bool("has_previous_response_id", previousResponseID != ""),
+ zap.String("previous_response_id_kind", previousResponseIDKind),
+ )
+ setOpsRequestContext(c, reqModel, true, firstMessage)
+
+ var currentUserRelease func()
+ var currentAccountRelease func()
+ releaseTurnSlots := func() {
+ if currentAccountRelease != nil {
+ currentAccountRelease()
+ currentAccountRelease = nil
+ }
+ if currentUserRelease != nil {
+ currentUserRelease()
+ currentUserRelease = nil
+ }
+ }
+ // 必须尽早注册,确保任何 early return 都能释放已获取的并发槽位。
+ defer releaseTurnSlots()
+
+ userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
+ if err != nil {
+ reqLog.Warn("openai.websocket_user_slot_acquire_failed", zap.Error(err))
+ closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot")
+ return
+ }
+ if !userAcquired {
+ closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later")
+ return
+ }
+ currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
+
+ subscription, _ := middleware2.GetSubscriptionFromContext(c)
+ if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
+ reqLog.Info("openai.websocket_billing_eligibility_check_failed", zap.Error(err))
+ closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "billing check failed")
+ return
+ }
+
+ sessionHash := h.gatewayService.GenerateSessionHashWithFallback(
+ c,
+ firstMessage,
+ openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID),
+ )
+ selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
+ ctx,
+ apiKey.GroupID,
+ previousResponseID,
+ sessionHash,
+ reqModel,
+ nil,
+ service.OpenAIUpstreamTransportResponsesWebsocketV2,
+ )
+ if err != nil {
+ reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))
+ closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
+ return
+ }
+ if selection == nil || selection.Account == nil {
+ closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
+ return
+ }
+
+ account := selection.Account
+ accountMaxConcurrency := account.Concurrency
+ if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
+ accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
+ }
+ accountReleaseFunc := selection.ReleaseFunc
+ if !selection.Acquired {
+ if selection.WaitPlan == nil {
+ closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
+ return
+ }
+ fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
+ ctx,
+ account.ID,
+ selection.WaitPlan.MaxConcurrency,
+ )
+ if err != nil {
+ reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
+ closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot")
+ return
+ }
+ if !fastAcquired {
+ closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
+ return
+ }
+ accountReleaseFunc = fastReleaseFunc
+ }
+ currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
+ if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
+ reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
+ }
+
+ token, _, err := h.gatewayService.GetAccessToken(ctx, account)
+ if err != nil {
+ reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
+ closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
+ return
+ }
+
+ reqLog.Debug("openai.websocket_account_selected",
+ zap.Int64("account_id", account.ID),
+ zap.String("account_name", account.Name),
+ zap.String("schedule_layer", scheduleDecision.Layer),
+ zap.Int("candidate_count", scheduleDecision.CandidateCount),
+ )
+
+ hooks := &service.OpenAIWSIngressHooks{
+ BeforeTurn: func(turn int) error {
+ if turn == 1 {
+ return nil
+ }
+ // 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。
+ releaseTurnSlots()
+ // 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。
+ userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
+ if err != nil {
+ return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
+ }
+ if !userAcquired {
+ return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
+ }
+ accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
+ if err != nil {
+ if userReleaseFunc != nil {
+ userReleaseFunc()
+ }
+ return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
+ }
+ if !accountAcquired {
+ if userReleaseFunc != nil {
+ userReleaseFunc()
+ }
+ return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
+ }
+ currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
+ currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
+ return nil
+ },
+ AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
+ releaseTurnSlots()
+ if turnErr != nil || result == nil {
+ return
+ }
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
+ h.submitUsageRecordTask(func(taskCtx context.Context) {
+ if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
+ Result: result,
+ APIKey: apiKey,
+ User: apiKey.User,
+ Account: account,
+ Subscription: subscription,
+ UserAgent: userAgent,
+ IPAddress: clientIP,
+ APIKeyService: h.apiKeyService,
+ }); err != nil {
+ reqLog.Error("openai.websocket_record_usage_failed",
+ zap.Int64("account_id", account.ID),
+ zap.String("request_id", result.RequestID),
+ zap.Error(err),
+ )
+ }
+ })
+ },
+ }
+
+ if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil {
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
+ closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
+ reqLog.Warn("openai.websocket_proxy_failed",
+ zap.Int64("account_id", account.ID),
+ zap.Error(err),
+ zap.String("close_status", closeStatus),
+ zap.String("close_reason", closeReason),
+ )
+ var closeErr *service.OpenAIWSClientCloseError
+ if errors.As(err, &closeErr) {
+ closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
+ return
+ }
+ closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
+ return
+ }
+ reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
+}
+
+func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) {
+ recovered := recover()
+ if recovered == nil {
+ return
+ }
+
+ started := false
+ if streamStarted != nil {
+ started = *streamStarted
+ }
+ wroteFallback := h.ensureForwardErrorResponse(c, started)
+ requestLogger(c, "handler.openai_gateway.responses").Error(
+ "openai.responses_panic_recovered",
+ zap.Bool("fallback_error_response_written", wroteFallback),
+ zap.Any("panic", recovered),
+ zap.ByteString("stack", debug.Stack()),
+ )
+}
+
+func (h *OpenAIGatewayHandler) ensureResponsesDependencies(c *gin.Context, reqLog *zap.Logger) bool {
+ missing := h.missingResponsesDependencies()
+ if len(missing) == 0 {
+ return true
+ }
+
+ if reqLog == nil {
+ reqLog = requestLogger(c, "handler.openai_gateway.responses")
+ }
+ reqLog.Error("openai.handler_dependencies_missing", zap.Strings("missing_dependencies", missing))
+
+ if c != nil && c.Writer != nil && !c.Writer.Written() {
+ c.JSON(http.StatusServiceUnavailable, gin.H{
+ "error": gin.H{
+ "type": "api_error",
+ "message": "Service temporarily unavailable",
+ },
+ })
+ }
+ return false
+}
+
+func (h *OpenAIGatewayHandler) missingResponsesDependencies() []string {
+ missing := make([]string, 0, 5)
+ if h == nil {
+ return append(missing, "handler")
+ }
+ if h.gatewayService == nil {
+ missing = append(missing, "gatewayService")
+ }
+ if h.billingCacheService == nil {
+ missing = append(missing, "billingCacheService")
+ }
+ if h.apiKeyService == nil {
+ missing = append(missing, "apiKeyService")
+ }
+ if h.concurrencyHelper == nil || h.concurrencyHelper.concurrencyService == nil {
+ missing = append(missing, "concurrencyHelper")
+ }
+ return missing
+}
+
func getContextInt64(c *gin.Context, key string) (int64, bool) {
if c == nil || key == "" {
return 0, false
@@ -444,6 +892,14 @@ func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTas
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
+ defer func() {
+ if recovered := recover(); recovered != nil {
+ logger.L().With(
+ zap.String("component", "handler.openai_gateway.responses"),
+ zap.Any("panic", recovered),
+ ).Error("openai.usage_record_task_panic_recovered")
+ }
+ }()
task(ctx)
}
@@ -515,19 +971,8 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
- // Send error event in OpenAI SSE format with proper JSON marshaling
- errorData := map[string]any{
- "error": map[string]string{
- "type": errType,
- "message": message,
- },
- }
- jsonBytes, err := json.Marshal(errorData)
- if err != nil {
- _ = c.Error(err)
- return
- }
- errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
+ // SSE 错误事件固定 schema,使用 Quote 直拼可避免额外 Marshal 分配。
+ errorEvent := "event: error\ndata: " + `{"error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n"
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
@@ -549,6 +994,16 @@ func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, stream
return true
}
+func shouldLogOpenAIForwardFailureAsWarn(c *gin.Context, wroteFallback bool) bool {
+ if wroteFallback {
+ return false
+ }
+ if c == nil || c.Writer == nil {
+ return false
+ }
+ return c.Writer.Written()
+}
+
// errorResponse returns OpenAI API format error response
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
@@ -558,3 +1013,61 @@ func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType
},
})
}
+
+func setOpenAIClientTransportHTTP(c *gin.Context) {
+ service.SetOpenAIClientTransport(c, service.OpenAIClientTransportHTTP)
+}
+
+func setOpenAIClientTransportWS(c *gin.Context) {
+ service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS)
+}
+
+func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string {
+ gid := int64(0)
+ if groupID != nil {
+ gid = *groupID
+ }
+ return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID)
+}
+
+func isOpenAIWSUpgradeRequest(r *http.Request) bool {
+ if r == nil {
+ return false
+ }
+ if !strings.EqualFold(strings.TrimSpace(r.Header.Get("Upgrade")), "websocket") {
+ return false
+ }
+ return strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Connection"))), "upgrade")
+}
+
+func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason string) {
+ if conn == nil {
+ return
+ }
+ reason = strings.TrimSpace(reason)
+ if len(reason) > 120 {
+ reason = reason[:120]
+ }
+ _ = conn.Close(status, reason)
+ _ = conn.CloseNow()
+}
+
+func summarizeWSCloseErrorForLog(err error) (string, string) {
+ if err == nil {
+ return "-", "-"
+ }
+ statusCode := coderws.CloseStatus(err)
+ if statusCode == -1 {
+ return "-", "-"
+ }
+ closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String())
+ closeReason := "-"
+ var closeErr coderws.CloseError
+ if errors.As(err, &closeErr) {
+ reason := strings.TrimSpace(closeErr.Reason)
+ if reason != "" {
+ closeReason = reason
+ }
+ }
+ return closeStatus, closeReason
+}
diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go
index 1ca52c2d..a26b3a0c 100644
--- a/backend/internal/handler/openai_gateway_handler_test.go
+++ b/backend/internal/handler/openai_gateway_handler_test.go
@@ -1,12 +1,19 @@
package handler
import (
+ "context"
"encoding/json"
+ "errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
+ "time"
+ pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
+ "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -105,6 +112,27 @@ func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
assert.Equal(t, "test error", errorObj["message"])
}
+func TestReadRequestBodyWithPrealloc(t *testing.T) {
+ payload := `{"model":"gpt-5","input":"hello"}`
+ req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(payload))
+ req.ContentLength = int64(len(payload))
+
+ body, err := pkghttputil.ReadRequestBodyWithPrealloc(req)
+ require.NoError(t, err)
+ require.Equal(t, payload, string(body))
+}
+
+func TestReadRequestBodyWithPrealloc_MaxBytesError(t *testing.T) {
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(strings.Repeat("x", 8)))
+ req.Body = http.MaxBytesReader(rec, req.Body, 4)
+
+ _, err := pkghttputil.ReadRequestBodyWithPrealloc(req)
+ require.Error(t, err)
+ var maxErr *http.MaxBytesError
+ require.ErrorAs(t, err, &maxErr)
+}
+
func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
@@ -141,6 +169,387 @@ func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *test
assert.Equal(t, "already written", w.Body.String())
}
+func TestShouldLogOpenAIForwardFailureAsWarn(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ t.Run("fallback_written_should_not_downgrade", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+ require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, true))
+ })
+
+ t.Run("context_nil_should_not_downgrade", func(t *testing.T) {
+ require.False(t, shouldLogOpenAIForwardFailureAsWarn(nil, false))
+ })
+
+ t.Run("response_not_written_should_not_downgrade", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+ require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, false))
+ })
+
+ t.Run("response_already_written_should_downgrade", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+ c.String(http.StatusForbidden, "already written")
+ require.True(t, shouldLogOpenAIForwardFailureAsWarn(c, false))
+ })
+}
+
+func TestOpenAIRecoverResponsesPanic_WritesFallbackResponse(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
+
+ h := &OpenAIGatewayHandler{}
+ streamStarted := false
+ require.NotPanics(t, func() {
+ func() {
+ defer h.recoverResponsesPanic(c, &streamStarted)
+ panic("test panic")
+ }()
+ })
+
+ require.Equal(t, http.StatusBadGateway, w.Code)
+
+ var parsed map[string]any
+ err := json.Unmarshal(w.Body.Bytes(), &parsed)
+ require.NoError(t, err)
+
+ errorObj, ok := parsed["error"].(map[string]any)
+ require.True(t, ok)
+ assert.Equal(t, "upstream_error", errorObj["type"])
+ assert.Equal(t, "Upstream request failed", errorObj["message"])
+}
+
+func TestOpenAIRecoverResponsesPanic_NoPanicNoWrite(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
+
+ h := &OpenAIGatewayHandler{}
+ streamStarted := false
+ require.NotPanics(t, func() {
+ func() {
+ defer h.recoverResponsesPanic(c, &streamStarted)
+ }()
+ })
+
+ require.False(t, c.Writer.Written())
+ assert.Equal(t, "", w.Body.String())
+}
+
+func TestOpenAIRecoverResponsesPanic_DoesNotOverrideWrittenResponse(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
+ c.String(http.StatusTeapot, "already written")
+
+ h := &OpenAIGatewayHandler{}
+ streamStarted := false
+ require.NotPanics(t, func() {
+ func() {
+ defer h.recoverResponsesPanic(c, &streamStarted)
+ panic("test panic")
+ }()
+ })
+
+ require.Equal(t, http.StatusTeapot, w.Code)
+ assert.Equal(t, "already written", w.Body.String())
+}
+
+func TestOpenAIMissingResponsesDependencies(t *testing.T) {
+ t.Run("nil_handler", func(t *testing.T) {
+ var h *OpenAIGatewayHandler
+ require.Equal(t, []string{"handler"}, h.missingResponsesDependencies())
+ })
+
+ t.Run("all_dependencies_missing", func(t *testing.T) {
+ h := &OpenAIGatewayHandler{}
+ require.Equal(t,
+ []string{"gatewayService", "billingCacheService", "apiKeyService", "concurrencyHelper"},
+ h.missingResponsesDependencies(),
+ )
+ })
+
+ t.Run("all_dependencies_present", func(t *testing.T) {
+ h := &OpenAIGatewayHandler{
+ gatewayService: &service.OpenAIGatewayService{},
+ billingCacheService: &service.BillingCacheService{},
+ apiKeyService: &service.APIKeyService{},
+ concurrencyHelper: &ConcurrencyHelper{
+ concurrencyService: &service.ConcurrencyService{},
+ },
+ }
+ require.Empty(t, h.missingResponsesDependencies())
+ })
+}
+
+func TestOpenAIEnsureResponsesDependencies(t *testing.T) {
+ t.Run("missing_dependencies_returns_503", func(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
+
+ h := &OpenAIGatewayHandler{}
+ ok := h.ensureResponsesDependencies(c, nil)
+
+ require.False(t, ok)
+ require.Equal(t, http.StatusServiceUnavailable, w.Code)
+ var parsed map[string]any
+ err := json.Unmarshal(w.Body.Bytes(), &parsed)
+ require.NoError(t, err)
+ errorObj, exists := parsed["error"].(map[string]any)
+ require.True(t, exists)
+ assert.Equal(t, "api_error", errorObj["type"])
+ assert.Equal(t, "Service temporarily unavailable", errorObj["message"])
+ })
+
+ t.Run("already_written_response_not_overridden", func(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
+ c.String(http.StatusTeapot, "already written")
+
+ h := &OpenAIGatewayHandler{}
+ ok := h.ensureResponsesDependencies(c, nil)
+
+ require.False(t, ok)
+ require.Equal(t, http.StatusTeapot, w.Code)
+ assert.Equal(t, "already written", w.Body.String())
+ })
+
+ t.Run("dependencies_ready_returns_true_and_no_write", func(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
+
+ h := &OpenAIGatewayHandler{
+ gatewayService: &service.OpenAIGatewayService{},
+ billingCacheService: &service.BillingCacheService{},
+ apiKeyService: &service.APIKeyService{},
+ concurrencyHelper: &ConcurrencyHelper{
+ concurrencyService: &service.ConcurrencyService{},
+ },
+ }
+ ok := h.ensureResponsesDependencies(c, nil)
+
+ require.True(t, ok)
+ require.False(t, c.Writer.Written())
+ assert.Equal(t, "", w.Body.String())
+ })
+}
+
+func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{"model":"gpt-5","stream":false}`))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ groupID := int64(2)
+ c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
+ ID: 10,
+ GroupID: &groupID,
+ })
+ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
+ UserID: 1,
+ Concurrency: 1,
+ })
+
+ // 故意使用未初始化依赖,验证快速失败而不是崩溃。
+ h := &OpenAIGatewayHandler{}
+ require.NotPanics(t, func() {
+ h.Responses(c)
+ })
+
+ require.Equal(t, http.StatusServiceUnavailable, w.Code)
+
+ var parsed map[string]any
+ err := json.Unmarshal(w.Body.Bytes(), &parsed)
+ require.NoError(t, err)
+
+ errorObj, ok := parsed["error"].(map[string]any)
+ require.True(t, ok)
+ assert.Equal(t, "api_error", errorObj["type"])
+ assert.Equal(t, "Service temporarily unavailable", errorObj["message"])
+}
+
+func TestOpenAIResponses_SetsClientTransportHTTP(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(`{"model":"gpt-5"}`))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ h := &OpenAIGatewayHandler{}
+ h.Responses(c)
+
+ require.Equal(t, http.StatusUnauthorized, w.Code)
+ require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c))
+}
+
+func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(
+ `{"model":"gpt-5.1","stream":false,"previous_response_id":"msg_123456","input":[{"type":"input_text","text":"hello"}]}`,
+ ))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ groupID := int64(2)
+ c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
+ ID: 101,
+ GroupID: &groupID,
+ User: &service.User{ID: 1},
+ })
+ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
+ UserID: 1,
+ Concurrency: 1,
+ })
+
+ h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
+ h.Responses(c)
+
+ require.Equal(t, http.StatusBadRequest, w.Code)
+ require.Contains(t, w.Body.String(), "previous_response_id must be a response.id")
+}
+
+func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil)
+ c.Request.Header.Set("Upgrade", "websocket")
+ c.Request.Header.Set("Connection", "Upgrade")
+
+ h := &OpenAIGatewayHandler{}
+ h.ResponsesWebSocket(c)
+
+ require.Equal(t, http.StatusUnauthorized, w.Code)
+ require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c))
+}
+
+func TestOpenAIResponsesWebSocket_InvalidUpgradeDoesNotSetTransport(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil)
+
+ h := &OpenAIGatewayHandler{}
+ h.ResponsesWebSocket(c)
+
+ require.Equal(t, http.StatusUpgradeRequired, w.Code)
+ require.Equal(t, service.OpenAIClientTransportUnknown, service.GetOpenAIClientTransport(c))
+}
+
+func TestOpenAIResponsesWebSocket_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
+ wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+ err = clientConn.Write(writeCtx, coderws.MessageText, []byte(
+ `{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_abc123"}`,
+ ))
+ cancelWrite()
+ require.NoError(t, err)
+
+ readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
+ _, _, err = clientConn.Read(readCtx)
+ cancelRead()
+ require.Error(t, err)
+ var closeErr coderws.CloseError
+ require.ErrorAs(t, err, &closeErr)
+ require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
+ require.Contains(t, strings.ToLower(closeErr.Reason), "previous_response_id")
+}
+
+func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailure(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cache := &concurrencyCacheMock{
+ acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
+ return false, errors.New("user slot unavailable")
+ },
+ }
+ h := newOpenAIHandlerForPreviousResponseIDValidation(t, cache)
+ wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+ err = clientConn.Write(writeCtx, coderws.MessageText, []byte(
+ `{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_123"}`,
+ ))
+ cancelWrite()
+ require.NoError(t, err)
+
+ readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
+ _, _, err = clientConn.Read(readCtx)
+ cancelRead()
+ require.Error(t, err)
+ var closeErr coderws.CloseError
+ require.ErrorAs(t, err, &closeErr)
+ require.Equal(t, coderws.StatusInternalError, closeErr.Code)
+ require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot")
+}
+
+func TestSetOpenAIClientTransportHTTP(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+
+ setOpenAIClientTransportHTTP(c)
+ require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c))
+}
+
+func TestSetOpenAIClientTransportWS(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+
+ setOpenAIClientTransportWS(c)
+ require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c))
+}
+
// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性
func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
tests := []struct {
@@ -228,3 +637,41 @@ func TestOpenAIHandler_InstructionsInjection(t *testing.T) {
require.NoError(t, setErr)
require.True(t, gjson.ValidBytes(result))
}
+
+func newOpenAIHandlerForPreviousResponseIDValidation(t *testing.T, cache *concurrencyCacheMock) *OpenAIGatewayHandler {
+ t.Helper()
+ if cache == nil {
+ cache = &concurrencyCacheMock{
+ acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
+ return true, nil
+ },
+ acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
+ return true, nil
+ },
+ }
+ }
+ return &OpenAIGatewayHandler{
+ gatewayService: &service.OpenAIGatewayService{},
+ billingCacheService: &service.BillingCacheService{},
+ apiKeyService: &service.APIKeyService{},
+ concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
+ }
+}
+
+func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject middleware.AuthSubject) *httptest.Server {
+ t.Helper()
+ groupID := int64(2)
+ apiKey := &service.APIKey{
+ ID: 101,
+ GroupID: &groupID,
+ User: &service.User{ID: subject.UserID},
+ }
+ router := gin.New()
+ router.Use(func(c *gin.Context) {
+ c.Set(string(middleware.ContextKeyAPIKey), apiKey)
+ c.Set(string(middleware.ContextKeyUser), subject)
+ c.Next()
+ })
+ router.GET("/openai/v1/responses", h.ResponsesWebSocket)
+ return httptest.NewServer(router)
+}
diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go
index ab9a2167..2f53d655 100644
--- a/backend/internal/handler/ops_error_logger.go
+++ b/backend/internal/handler/ops_error_logger.go
@@ -311,6 +311,35 @@ type opsCaptureWriter struct {
buf bytes.Buffer
}
+const opsCaptureWriterLimit = 64 * 1024
+
+var opsCaptureWriterPool = sync.Pool{
+ New: func() any {
+ return &opsCaptureWriter{limit: opsCaptureWriterLimit}
+ },
+}
+
+func acquireOpsCaptureWriter(rw gin.ResponseWriter) *opsCaptureWriter {
+ w, ok := opsCaptureWriterPool.Get().(*opsCaptureWriter)
+ if !ok || w == nil {
+ w = &opsCaptureWriter{}
+ }
+ w.ResponseWriter = rw
+ w.limit = opsCaptureWriterLimit
+ w.buf.Reset()
+ return w
+}
+
+func releaseOpsCaptureWriter(w *opsCaptureWriter) {
+ if w == nil {
+ return
+ }
+ w.ResponseWriter = nil
+ w.limit = opsCaptureWriterLimit
+ w.buf.Reset()
+ opsCaptureWriterPool.Put(w)
+}
+
func (w *opsCaptureWriter) Write(b []byte) (int, error) {
if w.Status() >= 400 && w.limit > 0 && w.buf.Len() < w.limit {
remaining := w.limit - w.buf.Len()
@@ -342,7 +371,16 @@ func (w *opsCaptureWriter) WriteString(s string) (int, error) {
// - Streaming errors after the response has started (SSE) may still need explicit logging.
func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
return func(c *gin.Context) {
- w := &opsCaptureWriter{ResponseWriter: c.Writer, limit: 64 * 1024}
+ originalWriter := c.Writer
+ w := acquireOpsCaptureWriter(originalWriter)
+ defer func() {
+ // Restore the original writer before returning so outer middlewares
+ // don't observe a pooled wrapper that has been released.
+ if c.Writer == w {
+ c.Writer = originalWriter
+ }
+ releaseOpsCaptureWriter(w)
+ }()
c.Writer = w
c.Next()
@@ -624,8 +662,10 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
requestID = c.Writer.Header().Get("x-request-id")
}
- phase := classifyOpsPhase(parsed.ErrorType, parsed.Message, parsed.Code)
- isBusinessLimited := classifyOpsIsBusinessLimited(parsed.ErrorType, phase, parsed.Code, status, parsed.Message)
+ normalizedType := normalizeOpsErrorType(parsed.ErrorType, parsed.Code)
+
+ phase := classifyOpsPhase(normalizedType, parsed.Message, parsed.Code)
+ isBusinessLimited := classifyOpsIsBusinessLimited(normalizedType, phase, parsed.Code, status, parsed.Message)
errorOwner := classifyOpsErrorOwner(phase, parsed.Message)
errorSource := classifyOpsErrorSource(phase, parsed.Message)
@@ -647,8 +687,8 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
UserAgent: c.GetHeader("User-Agent"),
ErrorPhase: phase,
- ErrorType: normalizeOpsErrorType(parsed.ErrorType, parsed.Code),
- Severity: classifyOpsSeverity(parsed.ErrorType, status),
+ ErrorType: normalizedType,
+ Severity: classifyOpsSeverity(normalizedType, status),
StatusCode: status,
IsBusinessLimited: isBusinessLimited,
IsCountTokens: isCountTokensRequest(c),
@@ -660,7 +700,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
ErrorSource: errorSource,
ErrorOwner: errorOwner,
- IsRetryable: classifyOpsIsRetryable(parsed.ErrorType, status),
+ IsRetryable: classifyOpsIsRetryable(normalizedType, status),
RetryCount: 0,
CreatedAt: time.Now(),
}
@@ -901,8 +941,29 @@ func guessPlatformFromPath(path string) string {
}
}
+// isKnownOpsErrorType returns true if t is a recognized error type used by the
+// ops classification pipeline. Upstream proxies sometimes return garbage values
+// (e.g. the Go-serialized literal "") which would pollute phase/severity
+// classification if accepted blindly.
+func isKnownOpsErrorType(t string) bool {
+ switch t {
+ case "invalid_request_error",
+ "authentication_error",
+ "rate_limit_error",
+ "billing_error",
+ "subscription_error",
+ "upstream_error",
+ "overloaded_error",
+ "api_error",
+ "not_found_error",
+ "forbidden_error":
+ return true
+ }
+ return false
+}
+
func normalizeOpsErrorType(errType string, code string) string {
- if errType != "" {
+ if errType != "" && isKnownOpsErrorType(errType) {
return errType
}
switch strings.TrimSpace(code) {
diff --git a/backend/internal/handler/ops_error_logger_test.go b/backend/internal/handler/ops_error_logger_test.go
index a11fa1f2..679dd4ce 100644
--- a/backend/internal/handler/ops_error_logger_test.go
+++ b/backend/internal/handler/ops_error_logger_test.go
@@ -6,6 +6,7 @@ import (
"sync"
"testing"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
@@ -173,3 +174,103 @@ func TestEnqueueOpsErrorLog_EarlyReturnBranches(t *testing.T) {
enqueueOpsErrorLog(ops, entry)
require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
}
+
+func TestOpsCaptureWriterPool_ResetOnRelease(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodGet, "/test", nil)
+
+ writer := acquireOpsCaptureWriter(c.Writer)
+ require.NotNil(t, writer)
+ _, err := writer.buf.WriteString("temp-error-body")
+ require.NoError(t, err)
+
+ releaseOpsCaptureWriter(writer)
+
+ reused := acquireOpsCaptureWriter(c.Writer)
+ defer releaseOpsCaptureWriter(reused)
+
+ require.Zero(t, reused.buf.Len(), "writer should be reset before reuse")
+}
+
+func TestOpsErrorLoggerMiddleware_DoesNotBreakOuterMiddlewares(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ r := gin.New()
+ r.Use(middleware2.Recovery())
+ r.Use(middleware2.RequestLogger())
+ r.Use(middleware2.Logger())
+ r.GET("/v1/messages", OpsErrorLoggerMiddleware(nil), func(c *gin.Context) {
+ c.Status(http.StatusNoContent)
+ })
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/v1/messages", nil)
+
+ require.NotPanics(t, func() {
+ r.ServeHTTP(rec, req)
+ })
+ require.Equal(t, http.StatusNoContent, rec.Code)
+}
+
+func TestIsKnownOpsErrorType(t *testing.T) {
+ known := []string{
+ "invalid_request_error",
+ "authentication_error",
+ "rate_limit_error",
+ "billing_error",
+ "subscription_error",
+ "upstream_error",
+ "overloaded_error",
+ "api_error",
+ "not_found_error",
+ "forbidden_error",
+ }
+ for _, k := range known {
+ require.True(t, isKnownOpsErrorType(k), "expected known: %s", k)
+ }
+
+ unknown := []string{"", "null", "", "random_error", "some_new_type", "\u003e"}
+ for _, u := range unknown {
+ require.False(t, isKnownOpsErrorType(u), "expected unknown: %q", u)
+ }
+}
+
+func TestNormalizeOpsErrorType(t *testing.T) {
+ tests := []struct {
+ name string
+ errType string
+ code string
+ want string
+ }{
+ // Known types pass through.
+ {"known invalid_request_error", "invalid_request_error", "", "invalid_request_error"},
+ {"known rate_limit_error", "rate_limit_error", "", "rate_limit_error"},
+ {"known upstream_error", "upstream_error", "", "upstream_error"},
+
+ // Unknown/garbage types are rejected and fall through to code-based or default.
+ {"nil literal from upstream", "", "", "api_error"},
+ {"null string", "null", "", "api_error"},
+ {"random string", "something_weird", "", "api_error"},
+
+ // Unknown type but known code still maps correctly.
+ {"nil with INSUFFICIENT_BALANCE code", "", "INSUFFICIENT_BALANCE", "billing_error"},
+ {"nil with USAGE_LIMIT_EXCEEDED code", "", "USAGE_LIMIT_EXCEEDED", "subscription_error"},
+
+ // Empty type falls through to code-based mapping.
+ {"empty type with balance code", "", "INSUFFICIENT_BALANCE", "billing_error"},
+ {"empty type with subscription code", "", "SUBSCRIPTION_NOT_FOUND", "subscription_error"},
+ {"empty type no code", "", "", "api_error"},
+
+ // Known type overrides conflicting code-based mapping.
+ {"known type overrides conflicting code", "rate_limit_error", "INSUFFICIENT_BALANCE", "rate_limit_error"},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := normalizeOpsErrorType(tt.errType, tt.code)
+ require.Equal(t, tt.want, got)
+ })
+ }
+}
diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go
index 2029f116..2141a9ee 100644
--- a/backend/internal/handler/setting_handler.go
+++ b/backend/internal/handler/setting_handler.go
@@ -51,6 +51,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
+ SoraClientEnabled: settings.SoraClientEnabled,
Version: h.version,
})
}
diff --git a/backend/internal/handler/sora_client_handler.go b/backend/internal/handler/sora_client_handler.go
new file mode 100644
index 00000000..80acc833
--- /dev/null
+++ b/backend/internal/handler/sora_client_handler.go
@@ -0,0 +1,979 @@
+package handler
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ // 上游模型缓存 TTL
+ modelCacheTTL = 1 * time.Hour // 上游获取成功
+ modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地)
+)
+
+// SoraClientHandler 处理 Sora 客户端 API 请求。
+type SoraClientHandler struct {
+ genService *service.SoraGenerationService
+ quotaService *service.SoraQuotaService
+ s3Storage *service.SoraS3Storage
+ soraGatewayService *service.SoraGatewayService
+ gatewayService *service.GatewayService
+ mediaStorage *service.SoraMediaStorage
+ apiKeyService *service.APIKeyService
+
+ // 上游模型缓存
+ modelCacheMu sync.RWMutex
+ cachedFamilies []service.SoraModelFamily
+ modelCacheTime time.Time
+ modelCacheUpstream bool // 是否来自上游(决定 TTL)
+}
+
+// NewSoraClientHandler 创建 Sora 客户端 Handler。
+func NewSoraClientHandler(
+ genService *service.SoraGenerationService,
+ quotaService *service.SoraQuotaService,
+ s3Storage *service.SoraS3Storage,
+ soraGatewayService *service.SoraGatewayService,
+ gatewayService *service.GatewayService,
+ mediaStorage *service.SoraMediaStorage,
+ apiKeyService *service.APIKeyService,
+) *SoraClientHandler {
+ return &SoraClientHandler{
+ genService: genService,
+ quotaService: quotaService,
+ s3Storage: s3Storage,
+ soraGatewayService: soraGatewayService,
+ gatewayService: gatewayService,
+ mediaStorage: mediaStorage,
+ apiKeyService: apiKeyService,
+ }
+}
+
+// GenerateRequest 生成请求。
+type GenerateRequest struct {
+ Model string `json:"model" binding:"required"`
+ Prompt string `json:"prompt" binding:"required"`
+ MediaType string `json:"media_type"` // video / image,默认 video
+ VideoCount int `json:"video_count,omitempty"` // 视频数量(1-3)
+ ImageInput string `json:"image_input,omitempty"` // 参考图(base64 或 URL)
+ APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID
+}
+
+// Generate 异步生成 — 创建 pending 记录后立即返回。
+// POST /api/v1/sora/generate
+func (h *SoraClientHandler) Generate(c *gin.Context) {
+ userID := getUserIDFromContext(c)
+ if userID == 0 {
+ response.Error(c, http.StatusUnauthorized, "未登录")
+ return
+ }
+
+ var req GenerateRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error())
+ return
+ }
+
+ if req.MediaType == "" {
+ req.MediaType = "video"
+ }
+ req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount)
+
+ // 并发数检查(最多 3 个)
+ activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if activeCount >= 3 {
+ response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
+ return
+ }
+
+ // 配额检查(粗略检查,实际文件大小在上传后才知道)
+ if h.quotaService != nil {
+ if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil {
+ var quotaErr *service.QuotaExceededError
+ if errors.As(err, "aErr) {
+ response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
+ return
+ }
+ response.Error(c, http.StatusForbidden, err.Error())
+ return
+ }
+ }
+
+ // 获取 API Key ID 和 Group ID
+ var apiKeyID *int64
+ var groupID *int64
+
+ if req.APIKeyID != nil && h.apiKeyService != nil {
+ // 前端传递了 api_key_id,需要校验
+ apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID)
+ if err != nil {
+ response.Error(c, http.StatusBadRequest, "API Key 不存在")
+ return
+ }
+ if apiKey.UserID != userID {
+ response.Error(c, http.StatusForbidden, "API Key 不属于当前用户")
+ return
+ }
+ if apiKey.Status != service.StatusAPIKeyActive {
+ response.Error(c, http.StatusForbidden, "API Key 不可用")
+ return
+ }
+ apiKeyID = &apiKey.ID
+ groupID = apiKey.GroupID
+ } else if id, ok := c.Get("api_key_id"); ok {
+ // 兼容 API Key 认证路径(/sora/v1/ 网关路由)
+ if v, ok := id.(int64); ok {
+ apiKeyID = &v
+ }
+ }
+
+ gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType)
+ if err != nil {
+ if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) {
+ response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
+ return
+ }
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // 启动后台异步生成 goroutine
+ go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount)
+
+ response.Success(c, gin.H{
+ "generation_id": gen.ID,
+ "status": gen.Status,
+ })
+}
+
+// processGeneration 后台异步执行 Sora 生成任务。
+// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。
+func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) {
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
+ defer cancel()
+
+ // 标记为生成中
+ if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil {
+ if errors.Is(err, service.ErrSoraGenerationStateConflict) {
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID)
+ return
+ }
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err)
+ return
+ }
+
+ logger.LegacyPrintf(
+ "handler.sora_client",
+ "[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d",
+ genID,
+ userID,
+ groupIDForLog(groupID),
+ model,
+ mediaType,
+ videoCount,
+ strings.TrimSpace(imageInput) != "",
+ len(strings.TrimSpace(prompt)),
+ )
+
+ // 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底
+ if groupID == nil {
+ ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
+ }
+
+ if h.gatewayService == nil {
+ _ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化")
+ return
+ }
+
+ // 选择 Sora 账号
+ account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model)
+ if err != nil {
+ logger.LegacyPrintf(
+ "handler.sora_client",
+ "[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v",
+ genID,
+ userID,
+ groupIDForLog(groupID),
+ model,
+ err,
+ )
+ _ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error())
+ return
+ }
+ logger.LegacyPrintf(
+ "handler.sora_client",
+ "[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s",
+ genID,
+ userID,
+ groupIDForLog(groupID),
+ model,
+ account.ID,
+ account.Name,
+ account.Platform,
+ account.Type,
+ )
+
+ // 构建 chat completions 请求体(非流式)
+ body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount))
+
+ if h.soraGatewayService == nil {
+ _ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化")
+ return
+ }
+
+ // 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL)
+ recorder := httptest.NewRecorder()
+ mockGinCtx, _ := gin.CreateTestContext(recorder)
+ mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil)
+
+ // 调用 Forward(非流式)
+ result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false)
+ if err != nil {
+ logger.LegacyPrintf(
+ "handler.sora_client",
+ "[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v",
+ genID,
+ account.ID,
+ model,
+ recorder.Code,
+ trimForLog(recorder.Body.String(), 400),
+ err,
+ )
+ // 检查是否已取消
+ gen, _ := h.genService.GetByID(ctx, genID, userID)
+ if gen != nil && gen.Status == service.SoraGenStatusCancelled {
+ return
+ }
+ _ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error())
+ return
+ }
+
+ // 提取媒体 URL(优先从 ForwardResult,其次从响应体解析)
+ mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder)
+ if mediaURL == "" {
+ logger.LegacyPrintf(
+ "handler.sora_client",
+ "[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s",
+ genID,
+ account.ID,
+ model,
+ recorder.Code,
+ trimForLog(recorder.Body.String(), 400),
+ )
+ _ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL")
+ return
+ }
+
+ // 检查任务是否已被取消
+ gen, _ := h.genService.GetByID(ctx, genID, userID)
+ if gen != nil && gen.Status == service.SoraGenStatusCancelled {
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID)
+ return
+ }
+
+ // 三层降级存储:S3 → 本地 → 上游临时 URL
+ storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs)
+
+ usageAdded := false
+ if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil {
+ if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil {
+ h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
+ var quotaErr *service.QuotaExceededError
+ if errors.As(err, "aErr) {
+ _ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间")
+ return
+ }
+ _ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error())
+ return
+ }
+ usageAdded = true
+ }
+
+ // 存储完成后再做一次取消检查,防止取消被 completed 覆盖。
+ gen, _ = h.genService.GetByID(ctx, genID, userID)
+ if gen != nil && gen.Status == service.SoraGenStatusCancelled {
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID)
+ h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
+ if usageAdded && h.quotaService != nil {
+ _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
+ }
+ return
+ }
+
+ // 标记完成
+ if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil {
+ if errors.Is(err, service.ErrSoraGenerationStateConflict) {
+ h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
+ if usageAdded && h.quotaService != nil {
+ _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
+ }
+ return
+ }
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err)
+ return
+ }
+
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize)
+}
+
+// storeMediaWithDegradation 实现三层降级存储链:S3 → 本地 → 上游。
+func (h *SoraClientHandler) storeMediaWithDegradation(
+ ctx context.Context, userID int64, mediaType string,
+ mediaURL string, mediaURLs []string,
+) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) {
+ urls := mediaURLs
+ if len(urls) == 0 {
+ urls = []string{mediaURL}
+ }
+
+ // 第一层:尝试 S3
+ if h.s3Storage != nil && h.s3Storage.Enabled(ctx) {
+ keys := make([]string, 0, len(urls))
+ var totalSize int64
+ allOK := true
+ for _, u := range urls {
+ key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u)
+ if err != nil {
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err)
+ allOK = false
+ // 清理已上传的文件
+ if len(keys) > 0 {
+ _ = h.s3Storage.DeleteObjects(ctx, keys)
+ }
+ break
+ }
+ keys = append(keys, key)
+ totalSize += size
+ }
+ if allOK && len(keys) > 0 {
+ accessURLs := make([]string, 0, len(keys))
+ for _, key := range keys {
+ accessURL, err := h.s3Storage.GetAccessURL(ctx, key)
+ if err != nil {
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err)
+ _ = h.s3Storage.DeleteObjects(ctx, keys)
+ allOK = false
+ break
+ }
+ accessURLs = append(accessURLs, accessURL)
+ }
+ if allOK && len(accessURLs) > 0 {
+ return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize
+ }
+ }
+ }
+
+ // 第二层:尝试本地存储
+ if h.mediaStorage != nil && h.mediaStorage.Enabled() {
+ storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls)
+ if err == nil && len(storedPaths) > 0 {
+ firstPath := storedPaths[0]
+ totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths)
+ if sizeErr != nil {
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr)
+ }
+ return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize
+ }
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err)
+ }
+
+ // 第三层:保留上游临时 URL
+ return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0
+}
+
+// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。
+func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte {
+ body := map[string]any{
+ "model": model,
+ "messages": []map[string]string{
+ {"role": "user", "content": prompt},
+ },
+ "stream": false,
+ }
+ if imageInput != "" {
+ body["image_input"] = imageInput
+ }
+ if videoCount > 1 {
+ body["video_count"] = videoCount
+ }
+ b, _ := json.Marshal(body)
+ return b
+}
+
+func normalizeVideoCount(mediaType string, videoCount int) int {
+ if mediaType != "video" {
+ return 1
+ }
+ if videoCount <= 0 {
+ return 1
+ }
+ if videoCount > 3 {
+ return 3
+ }
+ return videoCount
+}
+
+// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。
+// OAuth 路径:ForwardResult.MediaURL 已填充。
+// APIKey 路径:需从响应体解析 media_url / media_urls 字段。
+func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) {
+ // 优先从 ForwardResult 获取(OAuth 路径)
+ if result != nil && result.MediaURL != "" {
+ // 尝试从响应体获取完整 URL 列表
+ if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
+ return urls[0], urls
+ }
+ return result.MediaURL, []string{result.MediaURL}
+ }
+
+ // 从响应体解析(APIKey 路径)
+ if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
+ return urls[0], urls
+ }
+
+ return "", nil
+}
+
+// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。
+func parseMediaURLsFromBody(body []byte) []string {
+ if len(body) == 0 {
+ return nil
+ }
+ var resp map[string]any
+ if err := json.Unmarshal(body, &resp); err != nil {
+ return nil
+ }
+
+ // 优先 media_urls(多图数组)
+ if rawURLs, ok := resp["media_urls"]; ok {
+ if arr, ok := rawURLs.([]any); ok && len(arr) > 0 {
+ urls := make([]string, 0, len(arr))
+ for _, item := range arr {
+ if s, ok := item.(string); ok && s != "" {
+ urls = append(urls, s)
+ }
+ }
+ if len(urls) > 0 {
+ return urls
+ }
+ }
+ }
+
+ // 回退到 media_url(单个 URL)
+ if url, ok := resp["media_url"].(string); ok && url != "" {
+ return []string{url}
+ }
+
+ return nil
+}
+
+// ListGenerations 查询生成记录列表。
+// GET /api/v1/sora/generations
+func (h *SoraClientHandler) ListGenerations(c *gin.Context) {
+ userID := getUserIDFromContext(c)
+ if userID == 0 {
+ response.Error(c, http.StatusUnauthorized, "未登录")
+ return
+ }
+
+ page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
+ pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
+
+ params := service.SoraGenerationListParams{
+ UserID: userID,
+ Status: c.Query("status"),
+ StorageType: c.Query("storage_type"),
+ MediaType: c.Query("media_type"),
+ Page: page,
+ PageSize: pageSize,
+ }
+
+ gens, total, err := h.genService.List(c.Request.Context(), params)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // 为 S3 记录动态生成预签名 URL
+ for _, gen := range gens {
+ _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
+ }
+
+ response.Success(c, gin.H{
+ "data": gens,
+ "total": total,
+ "page": page,
+ })
+}
+
+// GetGeneration 查询生成记录详情。
+// GET /api/v1/sora/generations/:id
+func (h *SoraClientHandler) GetGeneration(c *gin.Context) {
+ userID := getUserIDFromContext(c)
+ if userID == 0 {
+ response.Error(c, http.StatusUnauthorized, "未登录")
+ return
+ }
+
+ id, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.Error(c, http.StatusBadRequest, "无效的 ID")
+ return
+ }
+
+ gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
+ if err != nil {
+ response.Error(c, http.StatusNotFound, err.Error())
+ return
+ }
+
+ _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
+ response.Success(c, gen)
+}
+
+// DeleteGeneration 删除生成记录。
+// DELETE /api/v1/sora/generations/:id
+func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) {
+ userID := getUserIDFromContext(c)
+ if userID == 0 {
+ response.Error(c, http.StatusUnauthorized, "未登录")
+ return
+ }
+
+ id, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.Error(c, http.StatusBadRequest, "无效的 ID")
+ return
+ }
+
+ gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
+ if err != nil {
+ response.Error(c, http.StatusNotFound, err.Error())
+ return
+ }
+
+ // 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。
+ if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil {
+ paths := gen.MediaURLs
+ if len(paths) == 0 && gen.MediaURL != "" {
+ paths = []string{gen.MediaURL}
+ }
+ if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil {
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err)
+ }
+ }
+
+ if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil {
+ response.Error(c, http.StatusNotFound, err.Error())
+ return
+ }
+
+ response.Success(c, gin.H{"message": "已删除"})
+}
+
+// GetQuota 查询用户存储配额。
+// GET /api/v1/sora/quota
+func (h *SoraClientHandler) GetQuota(c *gin.Context) {
+ userID := getUserIDFromContext(c)
+ if userID == 0 {
+ response.Error(c, http.StatusUnauthorized, "未登录")
+ return
+ }
+
+ if h.quotaService == nil {
+ response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"})
+ return
+ }
+
+ quota, err := h.quotaService.GetQuota(c.Request.Context(), userID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, quota)
+}
+
+// CancelGeneration 取消生成任务。
+// POST /api/v1/sora/generations/:id/cancel
+func (h *SoraClientHandler) CancelGeneration(c *gin.Context) {
+ userID := getUserIDFromContext(c)
+ if userID == 0 {
+ response.Error(c, http.StatusUnauthorized, "未登录")
+ return
+ }
+
+ id, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.Error(c, http.StatusBadRequest, "无效的 ID")
+ return
+ }
+
+ // 权限校验
+ gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
+ if err != nil {
+ response.Error(c, http.StatusNotFound, err.Error())
+ return
+ }
+ _ = gen
+
+ if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil {
+ if errors.Is(err, service.ErrSoraGenerationNotActive) {
+ response.Error(c, http.StatusConflict, "任务已结束,无法取消")
+ return
+ }
+ response.Error(c, http.StatusBadRequest, err.Error())
+ return
+ }
+
+ response.Success(c, gin.H{"message": "已取消"})
+}
+
+// SaveToStorage 手动保存 upstream 记录到 S3。
+// POST /api/v1/sora/generations/:id/save
+func (h *SoraClientHandler) SaveToStorage(c *gin.Context) {
+ userID := getUserIDFromContext(c)
+ if userID == 0 {
+ response.Error(c, http.StatusUnauthorized, "未登录")
+ return
+ }
+
+ id, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.Error(c, http.StatusBadRequest, "无效的 ID")
+ return
+ }
+
+ gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
+ if err != nil {
+ response.Error(c, http.StatusNotFound, err.Error())
+ return
+ }
+
+ if gen.StorageType != service.SoraStorageTypeUpstream {
+ response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存")
+ return
+ }
+ if gen.MediaURL == "" {
+ response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
+ return
+ }
+
+ if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) {
+ response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员")
+ return
+ }
+
+ sourceURLs := gen.MediaURLs
+ if len(sourceURLs) == 0 && gen.MediaURL != "" {
+ sourceURLs = []string{gen.MediaURL}
+ }
+ if len(sourceURLs) == 0 {
+ response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
+ return
+ }
+
+ uploadedKeys := make([]string, 0, len(sourceURLs))
+ accessURLs := make([]string, 0, len(sourceURLs))
+ var totalSize int64
+
+ for _, sourceURL := range sourceURLs {
+ objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL)
+ if uploadErr != nil {
+ if len(uploadedKeys) > 0 {
+ _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
+ }
+ var upstreamErr *service.UpstreamDownloadError
+ if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) {
+ response.Error(c, http.StatusGone, "媒体链接已过期,无法保存")
+ return
+ }
+ response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error())
+ return
+ }
+ accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey)
+ if err != nil {
+ uploadedKeys = append(uploadedKeys, objectKey)
+ _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
+ response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error())
+ return
+ }
+ uploadedKeys = append(uploadedKeys, objectKey)
+ accessURLs = append(accessURLs, accessURL)
+ totalSize += fileSize
+ }
+
+ usageAdded := false
+ if totalSize > 0 && h.quotaService != nil {
+ if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil {
+ _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
+ var quotaErr *service.QuotaExceededError
+ if errors.As(err, "aErr) {
+ response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
+ return
+ }
+ response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error())
+ return
+ }
+ usageAdded = true
+ }
+
+ if err := h.genService.UpdateStorageForCompleted(
+ c.Request.Context(),
+ id,
+ accessURLs[0],
+ accessURLs,
+ service.SoraStorageTypeS3,
+ uploadedKeys,
+ totalSize,
+ ); err != nil {
+ _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
+ if usageAdded && h.quotaService != nil {
+ _ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize)
+ }
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{
+ "message": "已保存到 S3",
+ "object_key": uploadedKeys[0],
+ "object_keys": uploadedKeys,
+ })
+}
+
+// GetStorageStatus 返回存储状态。
+// GET /api/v1/sora/storage-status
+func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) {
+ s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context())
+ s3Healthy := false
+ if s3Enabled {
+ s3Healthy = h.s3Storage.IsHealthy(c.Request.Context())
+ }
+ localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled()
+ response.Success(c, gin.H{
+ "s3_enabled": s3Enabled,
+ "s3_healthy": s3Healthy,
+ "local_enabled": localEnabled,
+ })
+}
+
+func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) {
+ switch storageType {
+ case service.SoraStorageTypeS3:
+ if h.s3Storage != nil && len(s3Keys) > 0 {
+ if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil {
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err)
+ }
+ }
+ case service.SoraStorageTypeLocal:
+ if h.mediaStorage != nil && len(localPaths) > 0 {
+ if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil {
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err)
+ }
+ }
+ }
+}
+
+// getUserIDFromContext 从 gin 上下文中提取用户 ID。
+func getUserIDFromContext(c *gin.Context) int64 {
+ if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
+ return subject.UserID
+ }
+
+ if id, ok := c.Get("user_id"); ok {
+ switch v := id.(type) {
+ case int64:
+ return v
+ case float64:
+ return int64(v)
+ case string:
+ n, _ := strconv.ParseInt(v, 10, 64)
+ return n
+ }
+ }
+ // 尝试从 JWT claims 获取
+ if id, ok := c.Get("userID"); ok {
+ if v, ok := id.(int64); ok {
+ return v
+ }
+ }
+ return 0
+}
+
+func groupIDForLog(groupID *int64) int64 {
+ if groupID == nil {
+ return 0
+ }
+ return *groupID
+}
+
+func trimForLog(raw string, maxLen int) string {
+ trimmed := strings.TrimSpace(raw)
+ if maxLen <= 0 || len(trimmed) <= maxLen {
+ return trimmed
+ }
+ return trimmed[:maxLen] + "...(truncated)"
+}
+
+// GetModels 获取可用 Sora 模型家族列表。
+// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。
+// GET /api/v1/sora/models
+func (h *SoraClientHandler) GetModels(c *gin.Context) {
+ families := h.getModelFamilies(c.Request.Context())
+ response.Success(c, families)
+}
+
+// getModelFamilies 获取模型家族列表(带缓存)。
+func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily {
+ // 读锁检查缓存
+ h.modelCacheMu.RLock()
+ ttl := modelCacheTTL
+ if !h.modelCacheUpstream {
+ ttl = modelCacheFailedTTL
+ }
+ if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
+ families := h.cachedFamilies
+ h.modelCacheMu.RUnlock()
+ return families
+ }
+ h.modelCacheMu.RUnlock()
+
+ // 写锁更新缓存
+ h.modelCacheMu.Lock()
+ defer h.modelCacheMu.Unlock()
+
+ // double-check
+ ttl = modelCacheTTL
+ if !h.modelCacheUpstream {
+ ttl = modelCacheFailedTTL
+ }
+ if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
+ return h.cachedFamilies
+ }
+
+ // 尝试从上游获取
+ families, err := h.fetchUpstreamModels(ctx)
+ if err != nil {
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err)
+ families = service.BuildSoraModelFamilies()
+ h.cachedFamilies = families
+ h.modelCacheTime = time.Now()
+ h.modelCacheUpstream = false
+ return families
+ }
+
+ logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families))
+ h.cachedFamilies = families
+ h.modelCacheTime = time.Now()
+ h.modelCacheUpstream = true
+ return families
+}
+
+// fetchUpstreamModels 从上游 Sora API 获取模型列表。
+func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) {
+ if h.gatewayService == nil {
+ return nil, fmt.Errorf("gatewayService 未初始化")
+ }
+
+ // 设置 ForcePlatform 用于 Sora 账号选择
+ ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
+
+ // 选择一个 Sora 账号
+ account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s")
+ if err != nil {
+ return nil, fmt.Errorf("选择 Sora 账号失败: %w", err)
+ }
+
+ // 仅支持 API Key 类型账号
+ if account.Type != service.AccountTypeAPIKey {
+ return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type)
+ }
+
+ apiKey := account.GetCredential("api_key")
+ if apiKey == "" {
+ return nil, fmt.Errorf("账号缺少 api_key")
+ }
+
+ baseURL := account.GetBaseURL()
+ if baseURL == "" {
+ return nil, fmt.Errorf("账号缺少 base_url")
+ }
+
+ // 构建上游模型列表请求
+ modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models"
+
+ reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil)
+ if err != nil {
+ return nil, fmt.Errorf("创建请求失败: %w", err)
+ }
+ req.Header.Set("Authorization", "Bearer "+apiKey)
+
+ client := &http.Client{Timeout: 10 * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("请求上游失败: %w", err)
+ }
+ defer func() {
+ _ = resp.Body.Close()
+ }()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode)
+ }
+
+ body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024))
+ if err != nil {
+ return nil, fmt.Errorf("读取响应失败: %w", err)
+ }
+
+ // 解析 OpenAI 格式的模型列表
+ var modelsResp struct {
+ Data []struct {
+ ID string `json:"id"`
+ } `json:"data"`
+ }
+ if err := json.Unmarshal(body, &modelsResp); err != nil {
+ return nil, fmt.Errorf("解析响应失败: %w", err)
+ }
+
+ if len(modelsResp.Data) == 0 {
+ return nil, fmt.Errorf("上游返回空模型列表")
+ }
+
+ // 提取模型 ID
+ modelIDs := make([]string, 0, len(modelsResp.Data))
+ for _, m := range modelsResp.Data {
+ modelIDs = append(modelIDs, m.ID)
+ }
+
+ // 转换为模型家族
+ families := service.BuildSoraModelFamiliesFromIDs(modelIDs)
+ if len(families) == 0 {
+ return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族")
+ }
+
+ return families, nil
+}
diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go
new file mode 100644
index 00000000..5df7fa0a
--- /dev/null
+++ b/backend/internal/handler/sora_client_handler_test.go
@@ -0,0 +1,3138 @@
+//go:build unit
+
+package handler
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "strings"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func init() {
+ gin.SetMode(gin.TestMode)
+}
+
+// ==================== Stub: SoraGenerationRepository ====================
+
+var _ service.SoraGenerationRepository = (*stubSoraGenRepo)(nil)
+
+type stubSoraGenRepo struct {
+ gens map[int64]*service.SoraGeneration
+ nextID int64
+ createErr error
+ getErr error
+ updateErr error
+ deleteErr error
+ listErr error
+ countErr error
+ countValue int64
+
+ // 条件性 Update 失败:前 updateFailAfterN 次成功,之后失败
+ updateCallCount *int32
+ updateFailAfterN int32
+
+ // 条件性 GetByID 状态覆盖:前 getByIDOverrideAfterN 次正常返回,之后返回 overrideStatus
+ getByIDCallCount int32
+ getByIDOverrideAfterN int32 // 0 = 不覆盖
+ getByIDOverrideStatus string
+}
+
+func newStubSoraGenRepo() *stubSoraGenRepo {
+ return &stubSoraGenRepo{gens: make(map[int64]*service.SoraGeneration), nextID: 1}
+}
+
+func (r *stubSoraGenRepo) Create(_ context.Context, gen *service.SoraGeneration) error {
+ if r.createErr != nil {
+ return r.createErr
+ }
+ gen.ID = r.nextID
+ r.nextID++
+ r.gens[gen.ID] = gen
+ return nil
+}
+func (r *stubSoraGenRepo) GetByID(_ context.Context, id int64) (*service.SoraGeneration, error) {
+ if r.getErr != nil {
+ return nil, r.getErr
+ }
+ gen, ok := r.gens[id]
+ if !ok {
+ return nil, fmt.Errorf("not found")
+ }
+ // 条件性状态覆盖:模拟外部取消等场景
+ if r.getByIDOverrideAfterN > 0 {
+ n := atomic.AddInt32(&r.getByIDCallCount, 1)
+ if n > r.getByIDOverrideAfterN {
+ cp := *gen
+ cp.Status = r.getByIDOverrideStatus
+ return &cp, nil
+ }
+ }
+ return gen, nil
+}
+func (r *stubSoraGenRepo) Update(_ context.Context, gen *service.SoraGeneration) error {
+ // 条件性失败:前 N 次成功,之后失败
+ if r.updateCallCount != nil {
+ n := atomic.AddInt32(r.updateCallCount, 1)
+ if n > r.updateFailAfterN {
+ return fmt.Errorf("conditional update error (call #%d)", n)
+ }
+ }
+ if r.updateErr != nil {
+ return r.updateErr
+ }
+ r.gens[gen.ID] = gen
+ return nil
+}
+func (r *stubSoraGenRepo) Delete(_ context.Context, id int64) error {
+ if r.deleteErr != nil {
+ return r.deleteErr
+ }
+ delete(r.gens, id)
+ return nil
+}
+func (r *stubSoraGenRepo) List(_ context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
+ if r.listErr != nil {
+ return nil, 0, r.listErr
+ }
+ var result []*service.SoraGeneration
+ for _, gen := range r.gens {
+ if gen.UserID != params.UserID {
+ continue
+ }
+ result = append(result, gen)
+ }
+ return result, int64(len(result)), nil
+}
+func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []string) (int64, error) {
+ if r.countErr != nil {
+ return 0, r.countErr
+ }
+ return r.countValue, nil
+}
+
+// ==================== 辅助函数 ====================
+
+func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler {
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ return &SoraClientHandler{genService: genService}
+}
+
+func makeGinContext(method, path, body string, userID int64) (*gin.Context, *httptest.ResponseRecorder) {
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ if body != "" {
+ c.Request = httptest.NewRequest(method, path, strings.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+ } else {
+ c.Request = httptest.NewRequest(method, path, nil)
+ }
+ if userID > 0 {
+ c.Set("user_id", userID)
+ }
+ return c, rec
+}
+
+func parseResponse(t *testing.T, rec *httptest.ResponseRecorder) map[string]any {
+ t.Helper()
+ var resp map[string]any
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ return resp
+}
+
+// ==================== 纯函数测试: buildAsyncRequestBody ====================
+
+func TestBuildAsyncRequestBody(t *testing.T) {
+ body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 1)
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal(body, &parsed))
+ require.Equal(t, "sora2-landscape-10s", parsed["model"])
+ require.Equal(t, false, parsed["stream"])
+
+ msgs := parsed["messages"].([]any)
+ require.Len(t, msgs, 1)
+ msg := msgs[0].(map[string]any)
+ require.Equal(t, "user", msg["role"])
+ require.Equal(t, "一只猫在跳舞", msg["content"])
+}
+
+func TestBuildAsyncRequestBody_EmptyPrompt(t *testing.T) {
+ body := buildAsyncRequestBody("gpt-image", "", "", 1)
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal(body, &parsed))
+ require.Equal(t, "gpt-image", parsed["model"])
+ msgs := parsed["messages"].([]any)
+ msg := msgs[0].(map[string]any)
+ require.Equal(t, "", msg["content"])
+}
+
+func TestBuildAsyncRequestBody_WithImageInput(t *testing.T) {
+ body := buildAsyncRequestBody("gpt-image", "一只猫", "https://example.com/ref.png", 1)
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal(body, &parsed))
+ require.Equal(t, "https://example.com/ref.png", parsed["image_input"])
+}
+
+func TestBuildAsyncRequestBody_WithVideoCount(t *testing.T) {
+ body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 3)
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal(body, &parsed))
+ require.Equal(t, float64(3), parsed["video_count"])
+}
+
+func TestNormalizeVideoCount(t *testing.T) {
+ require.Equal(t, 1, normalizeVideoCount("video", 0))
+ require.Equal(t, 2, normalizeVideoCount("video", 2))
+ require.Equal(t, 3, normalizeVideoCount("video", 5))
+ require.Equal(t, 1, normalizeVideoCount("image", 3))
+}
+
+// ==================== 纯函数测试: parseMediaURLsFromBody ====================
+
+func TestParseMediaURLsFromBody_MediaURLs(t *testing.T) {
+ urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","https://a.com/2.mp4"]}`))
+ require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
+}
+
+func TestParseMediaURLsFromBody_SingleMediaURL(t *testing.T) {
+ urls := parseMediaURLsFromBody([]byte(`{"media_url":"https://a.com/video.mp4"}`))
+ require.Equal(t, []string{"https://a.com/video.mp4"}, urls)
+}
+
+func TestParseMediaURLsFromBody_EmptyBody(t *testing.T) {
+ require.Nil(t, parseMediaURLsFromBody(nil))
+ require.Nil(t, parseMediaURLsFromBody([]byte{}))
+}
+
+func TestParseMediaURLsFromBody_InvalidJSON(t *testing.T) {
+ require.Nil(t, parseMediaURLsFromBody([]byte("not json")))
+}
+
+func TestParseMediaURLsFromBody_NoMediaFields(t *testing.T) {
+ require.Nil(t, parseMediaURLsFromBody([]byte(`{"data":"something"}`)))
+}
+
+func TestParseMediaURLsFromBody_EmptyMediaURL(t *testing.T) {
+ require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":""}`)))
+}
+
+func TestParseMediaURLsFromBody_EmptyMediaURLs(t *testing.T) {
+ require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":[]}`)))
+}
+
+func TestParseMediaURLsFromBody_MediaURLsPriority(t *testing.T) {
+ body := `{"media_url":"https://single.com/1.mp4","media_urls":["https://multi.com/a.mp4","https://multi.com/b.mp4"]}`
+ urls := parseMediaURLsFromBody([]byte(body))
+ require.Len(t, urls, 2)
+ require.Equal(t, "https://multi.com/a.mp4", urls[0])
+}
+
+func TestParseMediaURLsFromBody_FilterEmpty(t *testing.T) {
+ urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","","https://a.com/2.mp4"]}`))
+ require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
+}
+
+func TestParseMediaURLsFromBody_AllEmpty(t *testing.T) {
+ require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":["",""]}`)))
+}
+
+func TestParseMediaURLsFromBody_NonStringArray(t *testing.T) {
+ // media_urls 不是 string 数组
+ require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":"not-array"}`)))
+}
+
+func TestParseMediaURLsFromBody_MediaURLNotString(t *testing.T) {
+ require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":123}`)))
+}
+
+// ==================== 纯函数测试: extractMediaURLsFromResult ====================
+
+func TestExtractMediaURLsFromResult_OAuthPath(t *testing.T) {
+ result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"}
+ recorder := httptest.NewRecorder()
+ url, urls := extractMediaURLsFromResult(result, recorder)
+ require.Equal(t, "https://oauth.com/video.mp4", url)
+ require.Equal(t, []string{"https://oauth.com/video.mp4"}, urls)
+}
+
+func TestExtractMediaURLsFromResult_OAuthWithBody(t *testing.T) {
+ result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"}
+ recorder := httptest.NewRecorder()
+ _, _ = recorder.Write([]byte(`{"media_urls":["https://body.com/1.mp4","https://body.com/2.mp4"]}`))
+ url, urls := extractMediaURLsFromResult(result, recorder)
+ require.Equal(t, "https://body.com/1.mp4", url)
+ require.Len(t, urls, 2)
+}
+
+func TestExtractMediaURLsFromResult_APIKeyPath(t *testing.T) {
+ recorder := httptest.NewRecorder()
+ _, _ = recorder.Write([]byte(`{"media_url":"https://upstream.com/video.mp4"}`))
+ url, urls := extractMediaURLsFromResult(nil, recorder)
+ require.Equal(t, "https://upstream.com/video.mp4", url)
+ require.Equal(t, []string{"https://upstream.com/video.mp4"}, urls)
+}
+
+func TestExtractMediaURLsFromResult_NilResultEmptyBody(t *testing.T) {
+ recorder := httptest.NewRecorder()
+ url, urls := extractMediaURLsFromResult(nil, recorder)
+ require.Empty(t, url)
+ require.Nil(t, urls)
+}
+
+func TestExtractMediaURLsFromResult_EmptyMediaURL(t *testing.T) {
+ result := &service.ForwardResult{MediaURL: ""}
+ recorder := httptest.NewRecorder()
+ url, urls := extractMediaURLsFromResult(result, recorder)
+ require.Empty(t, url)
+ require.Nil(t, urls)
+}
+
+// ==================== getUserIDFromContext ====================
+
+func TestGetUserIDFromContext_Int64(t *testing.T) {
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest("GET", "/", nil)
+ c.Set("user_id", int64(42))
+ require.Equal(t, int64(42), getUserIDFromContext(c))
+}
+
+func TestGetUserIDFromContext_AuthSubject(t *testing.T) {
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest("GET", "/", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 777})
+ require.Equal(t, int64(777), getUserIDFromContext(c))
+}
+
+func TestGetUserIDFromContext_Float64(t *testing.T) {
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest("GET", "/", nil)
+ c.Set("user_id", float64(99))
+ require.Equal(t, int64(99), getUserIDFromContext(c))
+}
+
+func TestGetUserIDFromContext_String(t *testing.T) {
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest("GET", "/", nil)
+ c.Set("user_id", "123")
+ require.Equal(t, int64(123), getUserIDFromContext(c))
+}
+
+func TestGetUserIDFromContext_UserIDFallback(t *testing.T) {
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest("GET", "/", nil)
+ c.Set("userID", int64(55))
+ require.Equal(t, int64(55), getUserIDFromContext(c))
+}
+
+func TestGetUserIDFromContext_NoID(t *testing.T) {
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest("GET", "/", nil)
+ require.Equal(t, int64(0), getUserIDFromContext(c))
+}
+
+func TestGetUserIDFromContext_InvalidString(t *testing.T) {
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest("GET", "/", nil)
+ c.Set("user_id", "not-a-number")
+ require.Equal(t, int64(0), getUserIDFromContext(c))
+}
+
+// ==================== Handler: Generate ====================
+
+func TestGenerate_Unauthorized(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 0)
+ h.Generate(c)
+ require.Equal(t, http.StatusUnauthorized, rec.Code)
+}
+
+func TestGenerate_BadRequest_MissingModel(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestGenerate_BadRequest_MissingPrompt(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestGenerate_BadRequest_InvalidJSON(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{invalid`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestGenerate_TooManyRequests(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.countValue = 3
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusTooManyRequests, rec.Code)
+}
+
+func TestGenerate_CountError(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.countErr = fmt.Errorf("db error")
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+}
+
+func TestGenerate_Success(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"测试生成"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.NotZero(t, data["generation_id"])
+ require.Equal(t, "pending", data["status"])
+}
+
+func TestGenerate_DefaultMediaType(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "video", repo.gens[1].MediaType)
+}
+
+func TestGenerate_ImageMediaType(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"gpt-image","prompt":"test","media_type":"image"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "image", repo.gens[1].MediaType)
+}
+
+func TestGenerate_CreatePendingError(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.createErr = fmt.Errorf("create failed")
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+}
+
+func TestGenerate_NilQuotaServiceSkipsCheck(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+}
+
+func TestGenerate_APIKeyInContext(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ c.Set("api_key_id", int64(42))
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.NotNil(t, repo.gens[1].APIKeyID)
+ require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
+}
+
+func TestGenerate_NoAPIKeyInContext(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Nil(t, repo.gens[1].APIKeyID)
+}
+
+func TestGenerate_ConcurrencyBoundary(t *testing.T) {
+ // activeCount == 2 应该允许
+ repo := newStubSoraGenRepo()
+ repo.countValue = 2
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+}
+
+// ==================== Handler: ListGenerations ====================
+
+func TestListGenerations_Unauthorized(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 0)
+ h.ListGenerations(c)
+ require.Equal(t, http.StatusUnauthorized, rec.Code)
+}
+
+func TestListGenerations_Success(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream"}
+ repo.gens[2] = &service.SoraGeneration{ID: 2, UserID: 1, Model: "gpt-image", Status: "pending", StorageType: "none"}
+ repo.nextID = 3
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("GET", "/api/v1/sora/generations?page=1&page_size=10", "", 1)
+ h.ListGenerations(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ items := data["data"].([]any)
+ require.Len(t, items, 2)
+ require.Equal(t, float64(2), data["total"])
+}
+
+func TestListGenerations_ListError(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.listErr = fmt.Errorf("db error")
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1)
+ h.ListGenerations(c)
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+}
+
+func TestListGenerations_DefaultPagination(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ h := newTestSoraClientHandler(repo)
+ // 不传分页参数,应默认 page=1 page_size=20
+ c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1)
+ h.ListGenerations(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.Equal(t, float64(1), data["page"])
+}
+
+// ==================== Handler: GetGeneration ====================
+
+func TestGetGeneration_Unauthorized(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 0)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.GetGeneration(c)
+ require.Equal(t, http.StatusUnauthorized, rec.Code)
+}
+
+func TestGetGeneration_InvalidID(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("GET", "/api/v1/sora/generations/abc", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "abc"}}
+ h.GetGeneration(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestGetGeneration_NotFound(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("GET", "/api/v1/sora/generations/999", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "999"}}
+ h.GetGeneration(c)
+ require.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+func TestGetGeneration_WrongUser(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.GetGeneration(c)
+ require.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+func TestGetGeneration_Success(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.GetGeneration(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.Equal(t, float64(1), data["id"])
+}
+
+// ==================== Handler: DeleteGeneration ====================
+
+func TestDeleteGeneration_Unauthorized(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 0)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.DeleteGeneration(c)
+ require.Equal(t, http.StatusUnauthorized, rec.Code)
+}
+
+func TestDeleteGeneration_InvalidID(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/abc", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "abc"}}
+ h.DeleteGeneration(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestDeleteGeneration_NotFound(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/999", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "999"}}
+ h.DeleteGeneration(c)
+ require.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+func TestDeleteGeneration_WrongUser(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.DeleteGeneration(c)
+ require.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+func TestDeleteGeneration_Success(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.DeleteGeneration(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ _, exists := repo.gens[1]
+ require.False(t, exists)
+}
+
+// ==================== Handler: CancelGeneration ====================
+
+func TestCancelGeneration_Unauthorized(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 0)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.CancelGeneration(c)
+ require.Equal(t, http.StatusUnauthorized, rec.Code)
+}
+
+func TestCancelGeneration_InvalidID(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/cancel", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "abc"}}
+ h.CancelGeneration(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestCancelGeneration_NotFound(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/cancel", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "999"}}
+ h.CancelGeneration(c)
+ require.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+func TestCancelGeneration_WrongUser(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "pending"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.CancelGeneration(c)
+ require.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+func TestCancelGeneration_Pending(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.CancelGeneration(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "cancelled", repo.gens[1].Status)
+}
+
+func TestCancelGeneration_Generating(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "generating"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.CancelGeneration(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "cancelled", repo.gens[1].Status)
+}
+
+func TestCancelGeneration_Completed(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.CancelGeneration(c)
+ require.Equal(t, http.StatusConflict, rec.Code)
+}
+
+func TestCancelGeneration_Failed(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "failed"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.CancelGeneration(c)
+ require.Equal(t, http.StatusConflict, rec.Code)
+}
+
+func TestCancelGeneration_Cancelled(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.CancelGeneration(c)
+ require.Equal(t, http.StatusConflict, rec.Code)
+}
+
+// ==================== Handler: GetQuota ====================
+
+func TestGetQuota_Unauthorized(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 0)
+ h.GetQuota(c)
+ require.Equal(t, http.StatusUnauthorized, rec.Code)
+}
+
+func TestGetQuota_NilQuotaService(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1)
+ h.GetQuota(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.Equal(t, "unlimited", data["source"])
+}
+
+// ==================== Handler: GetModels ====================
+
+func TestGetModels(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("GET", "/api/v1/sora/models", "", 0)
+ h.GetModels(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].([]any)
+ require.Len(t, data, 4)
+ // 验证类型分布
+ videoCount, imageCount := 0, 0
+ for _, item := range data {
+ m := item.(map[string]any)
+ if m["type"] == "video" {
+ videoCount++
+ } else if m["type"] == "image" {
+ imageCount++
+ }
+ }
+ require.Equal(t, 3, videoCount)
+ require.Equal(t, 1, imageCount)
+}
+
+// ==================== Handler: GetStorageStatus ====================
+
+func TestGetStorageStatus_NilS3(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
+ h.GetStorageStatus(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.Equal(t, false, data["s3_enabled"])
+ require.Equal(t, false, data["s3_healthy"])
+ require.Equal(t, false, data["local_enabled"])
+}
+
+func TestGetStorageStatus_LocalEnabled(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "sora-storage-status-*")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Storage: config.SoraStorageConfig{
+ Type: "local",
+ LocalPath: tmpDir,
+ },
+ },
+ }
+ mediaStorage := service.NewSoraMediaStorage(cfg)
+ h := &SoraClientHandler{mediaStorage: mediaStorage}
+
+ c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
+ h.GetStorageStatus(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.Equal(t, false, data["s3_enabled"])
+ require.Equal(t, false, data["s3_healthy"])
+ require.Equal(t, true, data["local_enabled"])
+}
+
+// ==================== Handler: SaveToStorage ====================
+
+func TestSaveToStorage_Unauthorized(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 0)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusUnauthorized, rec.Code)
+}
+
+func TestSaveToStorage_InvalidID(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "abc"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestSaveToStorage_NotFound(t *testing.T) {
+ h := newTestSoraClientHandler(newStubSoraGenRepo())
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "999"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+func TestSaveToStorage_NotUpstream(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "s3", MediaURL: "https://example.com/v.mp4"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestSaveToStorage_EmptyMediaURL(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: ""}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestSaveToStorage_S3Nil(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusServiceUnavailable, rec.Code)
+ resp := parseResponse(t, rec)
+ require.Contains(t, fmt.Sprint(resp["message"]), "云存储")
+}
+
+func TestSaveToStorage_WrongUser(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
+ h := newTestSoraClientHandler(repo)
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+// ==================== storeMediaWithDegradation — nil guard 路径 ====================
+
+func TestStoreMediaWithDegradation_NilS3NilMedia(t *testing.T) {
+ h := &SoraClientHandler{}
+ url, urls, storageType, keys, size := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", "https://upstream.com/v.mp4", nil,
+ )
+ require.Equal(t, service.SoraStorageTypeUpstream, storageType)
+ require.Equal(t, "https://upstream.com/v.mp4", url)
+ require.Equal(t, []string{"https://upstream.com/v.mp4"}, urls)
+ require.Nil(t, keys)
+ require.Equal(t, int64(0), size)
+}
+
+func TestStoreMediaWithDegradation_NilGuardsMultiURL(t *testing.T) {
+ h := &SoraClientHandler{}
+ url, urls, storageType, keys, size := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{"https://a.com/1.mp4", "https://a.com/2.mp4"},
+ )
+ require.Equal(t, service.SoraStorageTypeUpstream, storageType)
+ require.Equal(t, "https://a.com/1.mp4", url)
+ require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
+ require.Nil(t, keys)
+ require.Equal(t, int64(0), size)
+}
+
+func TestStoreMediaWithDegradation_EmptyMediaURLsFallback(t *testing.T) {
+ h := &SoraClientHandler{}
+ url, _, storageType, _, _ := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{},
+ )
+ require.Equal(t, service.SoraStorageTypeUpstream, storageType)
+ require.Equal(t, "https://upstream.com/v.mp4", url)
+}
+
+// ==================== Stub: UserRepository (用于 SoraQuotaService) ====================
+
+var _ service.UserRepository = (*stubUserRepoForHandler)(nil)
+
+type stubUserRepoForHandler struct {
+ users map[int64]*service.User
+ updateErr error
+}
+
+func newStubUserRepoForHandler() *stubUserRepoForHandler {
+ return &stubUserRepoForHandler{users: make(map[int64]*service.User)}
+}
+
+func (r *stubUserRepoForHandler) GetByID(_ context.Context, id int64) (*service.User, error) {
+ if u, ok := r.users[id]; ok {
+ return u, nil
+ }
+ return nil, fmt.Errorf("user not found")
+}
+func (r *stubUserRepoForHandler) Update(_ context.Context, user *service.User) error {
+ if r.updateErr != nil {
+ return r.updateErr
+ }
+ r.users[user.ID] = user
+ return nil
+}
+func (r *stubUserRepoForHandler) Create(context.Context, *service.User) error { return nil }
+func (r *stubUserRepoForHandler) GetByEmail(context.Context, string) (*service.User, error) {
+ return nil, nil
+}
+func (r *stubUserRepoForHandler) GetFirstAdmin(context.Context) (*service.User, error) {
+ return nil, nil
+}
+func (r *stubUserRepoForHandler) Delete(context.Context, int64) error { return nil }
+func (r *stubUserRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (r *stubUserRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (r *stubUserRepoForHandler) UpdateBalance(context.Context, int64, float64) error { return nil }
+func (r *stubUserRepoForHandler) DeductBalance(context.Context, int64, float64) error { return nil }
+func (r *stubUserRepoForHandler) UpdateConcurrency(context.Context, int64, int) error { return nil }
+func (r *stubUserRepoForHandler) ExistsByEmail(context.Context, string) (bool, error) {
+ return false, nil
+}
+func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil }
+func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil }
+func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+
+// ==================== NewSoraClientHandler ====================
+
+func TestNewSoraClientHandler(t *testing.T) {
+ h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil)
+ require.NotNil(t, h)
+}
+
+func TestNewSoraClientHandler_WithAPIKeyService(t *testing.T) {
+ h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil)
+ require.NotNil(t, h)
+ require.Nil(t, h.apiKeyService)
+}
+
+// ==================== Stub: APIKeyRepository (用于 API Key 校验测试) ====================
+
+var _ service.APIKeyRepository = (*stubAPIKeyRepoForHandler)(nil)
+
+type stubAPIKeyRepoForHandler struct {
+ keys map[int64]*service.APIKey
+ getErr error
+}
+
+func newStubAPIKeyRepoForHandler() *stubAPIKeyRepoForHandler {
+ return &stubAPIKeyRepoForHandler{keys: make(map[int64]*service.APIKey)}
+}
+
+func (r *stubAPIKeyRepoForHandler) GetByID(_ context.Context, id int64) (*service.APIKey, error) {
+ if r.getErr != nil {
+ return nil, r.getErr
+ }
+ if k, ok := r.keys[id]; ok {
+ return k, nil
+ }
+ return nil, fmt.Errorf("api key not found: %d", id)
+}
+func (r *stubAPIKeyRepoForHandler) Create(context.Context, *service.APIKey) error { return nil }
+func (r *stubAPIKeyRepoForHandler) GetKeyAndOwnerID(_ context.Context, _ int64) (string, int64, error) {
+ return "", 0, nil
+}
+func (r *stubAPIKeyRepoForHandler) GetByKey(context.Context, string) (*service.APIKey, error) {
+ return nil, nil
+}
+func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*service.APIKey, error) {
+ return nil, nil
+}
+func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil }
+func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error { return nil }
+func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (r *stubAPIKeyRepoForHandler) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
+ return nil, nil
+}
+func (r *stubAPIKeyRepoForHandler) CountByUserID(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+func (r *stubAPIKeyRepoForHandler) ExistsByKey(context.Context, string) (bool, error) {
+ return false, nil
+}
+func (r *stubAPIKeyRepoForHandler) ListByGroupID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (r *stubAPIKeyRepoForHandler) SearchAPIKeys(context.Context, int64, string, int) ([]service.APIKey, error) {
+ return nil, nil
+}
+func (r *stubAPIKeyRepoForHandler) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+func (r *stubAPIKeyRepoForHandler) CountByGroupID(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+func (r *stubAPIKeyRepoForHandler) ListKeysByUserID(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+func (r *stubAPIKeyRepoForHandler) ListKeysByGroupID(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+func (r *stubAPIKeyRepoForHandler) IncrementQuotaUsed(_ context.Context, _ int64, _ float64) (float64, error) {
+ return 0, nil
+}
+func (r *stubAPIKeyRepoForHandler) UpdateLastUsed(context.Context, int64, time.Time) error {
+ return nil
+}
+
+// newTestAPIKeyService 创建测试用的 APIKeyService
+func newTestAPIKeyService(repo *stubAPIKeyRepoForHandler) *service.APIKeyService {
+ return service.NewAPIKeyService(repo, nil, nil, nil, nil, nil, &config.Config{})
+}
+
+// ==================== Generate: API Key 校验(前端传递 api_key_id)====================
+
+func TestGenerate_WithAPIKeyID_Success(t *testing.T) {
+ // 前端传递 api_key_id,校验通过 → 成功生成,记录关联 api_key_id
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ groupID := int64(5)
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyRepo.keys[42] = &service.APIKey{
+ ID: 42,
+ UserID: 1,
+ Status: service.StatusAPIKeyActive,
+ GroupID: &groupID,
+ }
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.NotZero(t, data["generation_id"])
+
+ // 验证 api_key_id 已关联到生成记录
+ gen := repo.gens[1]
+ require.NotNil(t, gen.APIKeyID)
+ require.Equal(t, int64(42), *gen.APIKeyID)
+}
+
+func TestGenerate_WithAPIKeyID_NotFound(t *testing.T) {
+ // 前端传递不存在的 api_key_id → 400
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":999}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ resp := parseResponse(t, rec)
+ require.Contains(t, fmt.Sprint(resp["message"]), "不存在")
+}
+
+func TestGenerate_WithAPIKeyID_WrongUser(t *testing.T) {
+ // 前端传递别人的 api_key_id → 403
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyRepo.keys[42] = &service.APIKey{
+ ID: 42,
+ UserID: 999, // 属于 user 999
+ Status: service.StatusAPIKeyActive,
+ }
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusForbidden, rec.Code)
+ resp := parseResponse(t, rec)
+ require.Contains(t, fmt.Sprint(resp["message"]), "不属于")
+}
+
+func TestGenerate_WithAPIKeyID_Disabled(t *testing.T) {
+ // 前端传递已禁用的 api_key_id → 403
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyRepo.keys[42] = &service.APIKey{
+ ID: 42,
+ UserID: 1,
+ Status: service.StatusAPIKeyDisabled,
+ }
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusForbidden, rec.Code)
+ resp := parseResponse(t, rec)
+ require.Contains(t, fmt.Sprint(resp["message"]), "不可用")
+}
+
+func TestGenerate_WithAPIKeyID_QuotaExhausted(t *testing.T) {
+ // 前端传递配额耗尽的 api_key_id → 403
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyRepo.keys[42] = &service.APIKey{
+ ID: 42,
+ UserID: 1,
+ Status: service.StatusAPIKeyQuotaExhausted,
+ }
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusForbidden, rec.Code)
+}
+
+func TestGenerate_WithAPIKeyID_Expired(t *testing.T) {
+ // 前端传递已过期的 api_key_id → 403
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyRepo.keys[42] = &service.APIKey{
+ ID: 42,
+ UserID: 1,
+ Status: service.StatusAPIKeyExpired,
+ }
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusForbidden, rec.Code)
+}
+
+func TestGenerate_WithAPIKeyID_NilAPIKeyService(t *testing.T) {
+ // apiKeyService 为 nil 时忽略 api_key_id → 正常生成但不记录 api_key_id
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ h := &SoraClientHandler{genService: genService} // apiKeyService = nil
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ // apiKeyService 为 nil → 跳过校验 → api_key_id 不记录
+ require.Nil(t, repo.gens[1].APIKeyID)
+}
+
+func TestGenerate_WithAPIKeyID_NilGroupID(t *testing.T) {
+ // api_key 有效但 GroupID 为 nil → 成功,groupID 为 nil
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyRepo.keys[42] = &service.APIKey{
+ ID: 42,
+ UserID: 1,
+ Status: service.StatusAPIKeyActive,
+ GroupID: nil, // 无分组
+ }
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.NotNil(t, repo.gens[1].APIKeyID)
+ require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
+}
+
+func TestGenerate_NoAPIKeyID_NoContext_NilResult(t *testing.T) {
+ // 既无 api_key_id 字段也无 context 中的 api_key_id → api_key_id 为 nil
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Nil(t, repo.gens[1].APIKeyID)
+}
+
+func TestGenerate_WithAPIKeyIDInBody_OverridesContext(t *testing.T) {
+ // 同时有 body api_key_id 和 context api_key_id → 优先使用 body 的
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ groupID := int64(10)
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyRepo.keys[42] = &service.APIKey{
+ ID: 42,
+ UserID: 1,
+ Status: service.StatusAPIKeyActive,
+ GroupID: &groupID,
+ }
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
+ c.Set("api_key_id", int64(99)) // context 中有另一个 api_key_id
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ // 应使用 body 中的 api_key_id=42,而不是 context 中的 99
+ require.NotNil(t, repo.gens[1].APIKeyID)
+ require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
+}
+
+func TestGenerate_WithContextAPIKeyID_FallbackPath(t *testing.T) {
+ // 无 body api_key_id,但 context 有 → 使用 context 中的(兼容网关路由)
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ c.Set("api_key_id", int64(99))
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ // 应使用 context 中的 api_key_id=99
+ require.NotNil(t, repo.gens[1].APIKeyID)
+ require.Equal(t, int64(99), *repo.gens[1].APIKeyID)
+}
+
+func TestGenerate_APIKeyID_Zero_IgnoredInJSON(t *testing.T) {
+ // JSON 中 api_key_id=0 被视为 omitempty → 仍然为指针值 0,需要传 nil 检查
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ apiKeyRepo := newStubAPIKeyRepoForHandler()
+ apiKeyService := newTestAPIKeyService(apiKeyRepo)
+
+ h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
+ // JSON 中传了 api_key_id: 0 → 解析后 *int64(0),会触发校验
+ // api_key_id=0 不存在 → 400
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate",
+ `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":0}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+// ==================== processGeneration: groupID 传递与 ForcePlatform ====================
+
+func TestProcessGeneration_WithGroupID_NoForcePlatform(t *testing.T) {
+ // groupID 不为 nil → 不设置 ForcePlatform
+ // gatewayService 为 nil → MarkFailed → 检查错误消息不包含 ForcePlatform 相关
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService}
+
+ gid := int64(5)
+ h.processGeneration(1, 1, &gid, "sora2-landscape-10s", "test", "video", "", 1)
+ require.Equal(t, "failed", repo.gens[1].Status)
+ require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
+}
+
+func TestProcessGeneration_NilGroupID_SetsForcePlatform(t *testing.T) {
+ // groupID 为 nil → 设置 ForcePlatform → gatewayService 为 nil → MarkFailed
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService}
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
+ require.Equal(t, "failed", repo.gens[1].Status)
+ require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
+}
+
+func TestProcessGeneration_MarkGeneratingStateConflict(t *testing.T) {
+ // 任务状态已变化(如已取消)→ MarkGenerating 返回 ErrSoraGenerationStateConflict → 跳过
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService}
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
+ // 状态为 cancelled 时 MarkGenerating 不符合状态转换规则 → 应保持 cancelled
+ require.Equal(t, "cancelled", repo.gens[1].Status)
+}
+
+// ==================== GenerateRequest JSON 解析 ====================
+
+func TestGenerateRequest_WithAPIKeyID_JSONParsing(t *testing.T) {
+ // 验证 api_key_id 在 JSON 中正确解析为 *int64
+ var req GenerateRequest
+ err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":42}`), &req)
+ require.NoError(t, err)
+ require.NotNil(t, req.APIKeyID)
+ require.Equal(t, int64(42), *req.APIKeyID)
+}
+
+func TestGenerateRequest_WithoutAPIKeyID_JSONParsing(t *testing.T) {
+ // 不传 api_key_id → 解析后为 nil
+ var req GenerateRequest
+ err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test"}`), &req)
+ require.NoError(t, err)
+ require.Nil(t, req.APIKeyID)
+}
+
+func TestGenerateRequest_NullAPIKeyID_JSONParsing(t *testing.T) {
+ // api_key_id: null → 解析后为 nil
+ var req GenerateRequest
+ err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":null}`), &req)
+ require.NoError(t, err)
+ require.Nil(t, req.APIKeyID)
+}
+
+func TestGenerateRequest_FullFields_JSONParsing(t *testing.T) {
+ // 全字段解析
+ var req GenerateRequest
+ err := json.Unmarshal([]byte(`{
+ "model":"sora2-landscape-10s",
+ "prompt":"test prompt",
+ "media_type":"video",
+ "video_count":2,
+ "image_input":"data:image/png;base64,abc",
+ "api_key_id":100
+ }`), &req)
+ require.NoError(t, err)
+ require.Equal(t, "sora2-landscape-10s", req.Model)
+ require.Equal(t, "test prompt", req.Prompt)
+ require.Equal(t, "video", req.MediaType)
+ require.Equal(t, 2, req.VideoCount)
+ require.Equal(t, "data:image/png;base64,abc", req.ImageInput)
+ require.NotNil(t, req.APIKeyID)
+ require.Equal(t, int64(100), *req.APIKeyID)
+}
+
+func TestGenerateRequest_JSONSerialize_OmitsNilAPIKeyID(t *testing.T) {
+ // api_key_id 为 nil 时 JSON 序列化应省略
+ req := GenerateRequest{Model: "sora2", Prompt: "test"}
+ b, err := json.Marshal(req)
+ require.NoError(t, err)
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal(b, &parsed))
+ _, hasAPIKeyID := parsed["api_key_id"]
+ require.False(t, hasAPIKeyID, "api_key_id 为 nil 时应省略")
+}
+
+func TestGenerateRequest_JSONSerialize_IncludesAPIKeyID(t *testing.T) {
+ // api_key_id 不为 nil 时 JSON 序列化应包含
+ id := int64(42)
+ req := GenerateRequest{Model: "sora2", Prompt: "test", APIKeyID: &id}
+ b, err := json.Marshal(req)
+ require.NoError(t, err)
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal(b, &parsed))
+ require.Equal(t, float64(42), parsed["api_key_id"])
+}
+
+// ==================== GetQuota: 有配额服务 ====================
+
+func TestGetQuota_WithQuotaService_Success(t *testing.T) {
+ userRepo := newStubUserRepoForHandler()
+ userRepo.users[1] = &service.User{
+ ID: 1,
+ SoraStorageQuotaBytes: 10 * 1024 * 1024,
+ SoraStorageUsedBytes: 3 * 1024 * 1024,
+ }
+ quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
+
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{
+ genService: genService,
+ quotaService: quotaService,
+ }
+
+ c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1)
+ h.GetQuota(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.Equal(t, "user", data["source"])
+ require.Equal(t, float64(10*1024*1024), data["quota_bytes"])
+ require.Equal(t, float64(3*1024*1024), data["used_bytes"])
+}
+
+func TestGetQuota_WithQuotaService_Error(t *testing.T) {
+ // 用户不存在时 GetQuota 返回错误
+ userRepo := newStubUserRepoForHandler()
+ quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
+
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{
+ genService: genService,
+ quotaService: quotaService,
+ }
+
+ c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 999)
+ h.GetQuota(c)
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+}
+
+// ==================== Generate: 配额检查 ====================
+
+func TestGenerate_QuotaCheckFailed(t *testing.T) {
+ // 配额超限时返回 429
+ userRepo := newStubUserRepoForHandler()
+ userRepo.users[1] = &service.User{
+ ID: 1,
+ SoraStorageQuotaBytes: 1024,
+ SoraStorageUsedBytes: 1025, // 已超限
+ }
+ quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
+
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{
+ genService: genService,
+ quotaService: quotaService,
+ }
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusTooManyRequests, rec.Code)
+}
+
+func TestGenerate_QuotaCheckPassed(t *testing.T) {
+ // 配额充足时允许生成
+ userRepo := newStubUserRepoForHandler()
+ userRepo.users[1] = &service.User{
+ ID: 1,
+ SoraStorageQuotaBytes: 10 * 1024 * 1024,
+ SoraStorageUsedBytes: 0,
+ }
+ quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
+
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{
+ genService: genService,
+ quotaService: quotaService,
+ }
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+}
+
+// ==================== Stub: SettingRepository (用于 S3 存储测试) ====================
+
+var _ service.SettingRepository = (*stubSettingRepoForHandler)(nil)
+
+type stubSettingRepoForHandler struct {
+ values map[string]string
+}
+
+func newStubSettingRepoForHandler(values map[string]string) *stubSettingRepoForHandler {
+ if values == nil {
+ values = make(map[string]string)
+ }
+ return &stubSettingRepoForHandler{values: values}
+}
+
+func (r *stubSettingRepoForHandler) Get(_ context.Context, key string) (*service.Setting, error) {
+ if v, ok := r.values[key]; ok {
+ return &service.Setting{Key: key, Value: v}, nil
+ }
+ return nil, service.ErrSettingNotFound
+}
+func (r *stubSettingRepoForHandler) GetValue(_ context.Context, key string) (string, error) {
+ if v, ok := r.values[key]; ok {
+ return v, nil
+ }
+ return "", service.ErrSettingNotFound
+}
+func (r *stubSettingRepoForHandler) Set(_ context.Context, key, value string) error {
+ r.values[key] = value
+ return nil
+}
+func (r *stubSettingRepoForHandler) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ result := make(map[string]string)
+ for _, k := range keys {
+ if v, ok := r.values[k]; ok {
+ result[k] = v
+ }
+ }
+ return result, nil
+}
+func (r *stubSettingRepoForHandler) SetMultiple(_ context.Context, settings map[string]string) error {
+ for k, v := range settings {
+ r.values[k] = v
+ }
+ return nil
+}
+func (r *stubSettingRepoForHandler) GetAll(_ context.Context) (map[string]string, error) {
+ return r.values, nil
+}
+func (r *stubSettingRepoForHandler) Delete(_ context.Context, key string) error {
+ delete(r.values, key)
+ return nil
+}
+
+// ==================== S3 / MediaStorage 辅助函数 ====================
+
+// newS3StorageForHandler 创建指向指定 endpoint 的 S3Storage(用于测试)。
+func newS3StorageForHandler(endpoint string) *service.SoraS3Storage {
+ settingRepo := newStubSettingRepoForHandler(map[string]string{
+ "sora_s3_enabled": "true",
+ "sora_s3_endpoint": endpoint,
+ "sora_s3_region": "us-east-1",
+ "sora_s3_bucket": "test-bucket",
+ "sora_s3_access_key_id": "AKIATEST",
+ "sora_s3_secret_access_key": "test-secret",
+ "sora_s3_prefix": "sora",
+ "sora_s3_force_path_style": "true",
+ })
+ settingService := service.NewSettingService(settingRepo, &config.Config{})
+ return service.NewSoraS3Storage(settingService)
+}
+
+// newFakeSourceServer 创建返回固定内容的 HTTP 服务器(模拟上游媒体文件)。
+func newFakeSourceServer() *httptest.Server {
+ return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "video/mp4")
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte("fake video data for test"))
+ }))
+}
+
+// newFakeS3Server 创建模拟 S3 的 HTTP 服务器。
+// mode: "ok" 接受所有请求,"fail" 返回 403,"fail-second" 第一次成功第二次失败。
+func newFakeS3Server(mode string) *httptest.Server {
+ var counter atomic.Int32
+ return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, _ = io.Copy(io.Discard, r.Body)
+ _ = r.Body.Close()
+
+ switch mode {
+ case "ok":
+ w.Header().Set("ETag", `"test-etag"`)
+ w.WriteHeader(http.StatusOK)
+ case "fail":
+ w.WriteHeader(http.StatusForbidden)
+ _, _ = w.Write([]byte(`AccessDenied`))
+ case "fail-second":
+ n := counter.Add(1)
+ if n <= 1 {
+ w.Header().Set("ETag", `"test-etag"`)
+ w.WriteHeader(http.StatusOK)
+ } else {
+ w.WriteHeader(http.StatusForbidden)
+ _, _ = w.Write([]byte(`AccessDenied`))
+ }
+ }
+ }))
+}
+
+// ==================== processGeneration 直接调用测试 ====================
+
+func TestProcessGeneration_MarkGeneratingFails(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ repo.updateErr = fmt.Errorf("db error")
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService}
+
+ // 直接调用(非 goroutine),MarkGenerating 失败 → 早退
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
+ // MarkGenerating 在调用 repo.Update 前已修改内存对象为 "generating"
+ // repo.Update 返回错误 → processGeneration 早退,不会继续到 MarkFailed
+ // 因此 ErrorMessage 为空(证明未调用 MarkFailed)
+ require.Equal(t, "generating", repo.gens[1].Status)
+ require.Empty(t, repo.gens[1].ErrorMessage)
+}
+
+func TestProcessGeneration_GatewayServiceNil(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService}
+ // gatewayService 未设置 → MarkFailed
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
+ require.Equal(t, "failed", repo.gens[1].Status)
+ require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
+}
+
+// ==================== storeMediaWithDegradation: S3 路径 ====================
+
+func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ h := &SoraClientHandler{s3Storage: s3Storage}
+
+ storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
+ )
+ require.Equal(t, service.SoraStorageTypeS3, storageType)
+ require.Len(t, s3Keys, 1)
+ require.NotEmpty(t, s3Keys[0])
+ require.Len(t, storedURLs, 1)
+ require.Equal(t, storedURL, storedURLs[0])
+ require.Contains(t, storedURL, fakeS3.URL)
+ require.Contains(t, storedURL, "/test-bucket/")
+ require.Greater(t, fileSize, int64(0))
+}
+
+func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ h := &SoraClientHandler{s3Storage: s3Storage}
+
+ urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
+ storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls,
+ )
+ require.Equal(t, service.SoraStorageTypeS3, storageType)
+ require.Len(t, s3Keys, 2)
+ require.Len(t, storedURLs, 2)
+ require.Equal(t, storedURL, storedURLs[0])
+ require.Contains(t, storedURLs[0], fakeS3.URL)
+ require.Contains(t, storedURLs[1], fakeS3.URL)
+ require.Greater(t, fileSize, int64(0))
+}
+
+func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) {
+ // 上游返回 404 → 下载失败 → S3 上传不会开始
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+ badSource := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusNotFound)
+ }))
+ defer badSource.Close()
+
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ h := &SoraClientHandler{s3Storage: s3Storage}
+
+ _, _, storageType, _, _ := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil,
+ )
+ require.Equal(t, service.SoraStorageTypeUpstream, storageType)
+}
+
+func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("fail")
+ defer fakeS3.Close()
+
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ h := &SoraClientHandler{s3Storage: s3Storage}
+
+ _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
+ )
+ // S3 失败,降级到 upstream
+ require.Equal(t, service.SoraStorageTypeUpstream, storageType)
+ require.Nil(t, s3Keys)
+}
+
+func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("fail-second")
+ defer fakeS3.Close()
+
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ h := &SoraClientHandler{s3Storage: s3Storage}
+
+ urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
+ _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls,
+ )
+ // 第二个 URL 上传失败 → 清理已上传 → 降级到 upstream
+ require.Equal(t, service.SoraStorageTypeUpstream, storageType)
+ require.Nil(t, s3Keys)
+}
+
+// ==================== storeMediaWithDegradation: 本地存储路径 ====================
+
+func TestStoreMediaWithDegradation_LocalStorageFails(t *testing.T) {
+ // 使用无效路径,EnsureLocalDirs 失败 → StoreFromURLs 返回 error
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Storage: config.SoraStorageConfig{
+ Type: "local",
+ LocalPath: "/dev/null/invalid_dir",
+ },
+ },
+ }
+ mediaStorage := service.NewSoraMediaStorage(cfg)
+ h := &SoraClientHandler{mediaStorage: mediaStorage}
+
+ _, _, storageType, _, _ := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", "https://upstream.com/v.mp4", nil,
+ )
+ // 本地存储失败,降级到 upstream
+ require.Equal(t, service.SoraStorageTypeUpstream, storageType)
+}
+
+func TestStoreMediaWithDegradation_LocalStorageSuccess(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "sora-handler-test-*")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpDir)
+
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Storage: config.SoraStorageConfig{
+ Type: "local",
+ LocalPath: tmpDir,
+ DownloadTimeoutSeconds: 5,
+ MaxDownloadBytes: 10 * 1024 * 1024,
+ },
+ },
+ }
+ mediaStorage := service.NewSoraMediaStorage(cfg)
+ h := &SoraClientHandler{mediaStorage: mediaStorage}
+
+ _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
+ )
+ require.Equal(t, service.SoraStorageTypeLocal, storageType)
+ require.Nil(t, s3Keys) // 本地存储不返回 S3 keys
+}
+
+func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "sora-handler-test-*")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpDir)
+
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("fail")
+ defer fakeS3.Close()
+
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Storage: config.SoraStorageConfig{
+ Type: "local",
+ LocalPath: tmpDir,
+ DownloadTimeoutSeconds: 5,
+ MaxDownloadBytes: 10 * 1024 * 1024,
+ },
+ },
+ }
+ mediaStorage := service.NewSoraMediaStorage(cfg)
+ h := &SoraClientHandler{
+ s3Storage: s3Storage,
+ mediaStorage: mediaStorage,
+ }
+
+ _, _, storageType, _, _ := h.storeMediaWithDegradation(
+ context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
+ )
+ // S3 失败 → 本地存储成功
+ require.Equal(t, service.SoraStorageTypeLocal, storageType)
+}
+
+// ==================== SaveToStorage: S3 路径 ====================
+
+func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("fail")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: sourceServer.URL + "/v.mp4",
+ }
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+ resp := parseResponse(t, rec)
+ require.Contains(t, resp["message"], "S3")
+}
+
+func TestSaveToStorage_UpstreamURLExpired(t *testing.T) {
+ expiredServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusForbidden)
+ }))
+ defer expiredServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: expiredServer.URL + "/v.mp4",
+ }
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusGone, rec.Code)
+ resp := parseResponse(t, rec)
+ require.Contains(t, fmt.Sprint(resp["message"]), "过期")
+}
+
+func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: sourceServer.URL + "/v.mp4",
+ }
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.Contains(t, data["message"], "S3")
+ require.NotEmpty(t, data["object_key"])
+ // 验证记录已更新为 S3 存储
+ require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
+}
+
+func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: sourceServer.URL + "/v1.mp4",
+ MediaURLs: []string{
+ sourceServer.URL + "/v1.mp4",
+ sourceServer.URL + "/v2.mp4",
+ },
+ }
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.Len(t, data["object_keys"].([]any), 2)
+ require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
+ require.Len(t, repo.gens[1].S3ObjectKeys, 2)
+ require.Len(t, repo.gens[1].MediaURLs, 2)
+}
+
+func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: sourceServer.URL + "/v.mp4",
+ }
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ userRepo := newStubUserRepoForHandler()
+ userRepo.users[1] = &service.User{
+ ID: 1,
+ SoraStorageQuotaBytes: 100 * 1024 * 1024,
+ SoraStorageUsedBytes: 0,
+ }
+ quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ // 验证配额已累加
+ require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0))
+}
+
+func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: sourceServer.URL + "/v.mp4",
+ }
+ // S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败
+ repo.updateErr = fmt.Errorf("db error")
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+}
+
+// ==================== GetStorageStatus: S3 路径 ====================
+
+func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) {
+ // S3 启用但 TestConnection 失败(fake 端点不响应 HeadBucket)
+ fakeS3 := newFakeS3Server("fail")
+ defer fakeS3.Close()
+
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ h := &SoraClientHandler{s3Storage: s3Storage}
+
+ c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
+ h.GetStorageStatus(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.Equal(t, true, data["s3_enabled"])
+ require.Equal(t, false, data["s3_healthy"])
+}
+
+func TestGetStorageStatus_S3EnabledHealthy(t *testing.T) {
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ h := &SoraClientHandler{s3Storage: s3Storage}
+
+ c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
+ h.GetStorageStatus(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ resp := parseResponse(t, rec)
+ data := resp["data"].(map[string]any)
+ require.Equal(t, true, data["s3_enabled"])
+ require.Equal(t, true, data["s3_healthy"])
+}
+
+// ==================== Stub: AccountRepository (用于 GatewayService) ====================
+
+var _ service.AccountRepository = (*stubAccountRepoForHandler)(nil)
+
+type stubAccountRepoForHandler struct {
+ accounts []service.Account
+}
+
+func (r *stubAccountRepoForHandler) Create(context.Context, *service.Account) error { return nil }
+func (r *stubAccountRepoForHandler) GetByID(_ context.Context, id int64) (*service.Account, error) {
+ for i := range r.accounts {
+ if r.accounts[i].ID == id {
+ return &r.accounts[i], nil
+ }
+ }
+ return nil, fmt.Errorf("account not found")
+}
+func (r *stubAccountRepoForHandler) GetByIDs(context.Context, []int64) ([]*service.Account, error) {
+ return nil, nil
+}
+func (r *stubAccountRepoForHandler) ExistsByID(context.Context, int64) (bool, error) {
+ return false, nil
+}
+func (r *stubAccountRepoForHandler) GetByCRSAccountID(context.Context, string) (*service.Account, error) {
+ return nil, nil
+}
+func (r *stubAccountRepoForHandler) FindByExtraField(context.Context, string, any) ([]service.Account, error) {
+ return nil, nil
+}
+func (r *stubAccountRepoForHandler) ListCRSAccountIDs(context.Context) (map[string]int64, error) {
+ return nil, nil
+}
+func (r *stubAccountRepoForHandler) Update(context.Context, *service.Account) error { return nil }
+func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error { return nil }
+func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]service.Account, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) {
+ return nil, nil
+}
+func (r *stubAccountRepoForHandler) ListActive(context.Context) ([]service.Account, error) {
+ return nil, nil
+}
+func (r *stubAccountRepoForHandler) ListByPlatform(context.Context, string) ([]service.Account, error) {
+ return nil, nil
+}
+func (r *stubAccountRepoForHandler) UpdateLastUsed(context.Context, int64) error { return nil }
+func (r *stubAccountRepoForHandler) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error {
+ return nil
+}
+func (r *stubAccountRepoForHandler) SetError(context.Context, int64, string) error { return nil }
+func (r *stubAccountRepoForHandler) ClearError(context.Context, int64) error { return nil }
+func (r *stubAccountRepoForHandler) SetSchedulable(context.Context, int64, bool) error {
+ return nil
+}
+func (r *stubAccountRepoForHandler) AutoPauseExpiredAccounts(context.Context, time.Time) (int64, error) {
+ return 0, nil
+}
+func (r *stubAccountRepoForHandler) BindGroups(context.Context, int64, []int64) error { return nil }
+func (r *stubAccountRepoForHandler) ListSchedulable(context.Context) ([]service.Account, error) {
+ return r.accounts, nil
+}
+func (r *stubAccountRepoForHandler) ListSchedulableByGroupID(context.Context, int64) ([]service.Account, error) {
+ return r.accounts, nil
+}
+func (r *stubAccountRepoForHandler) ListSchedulableByPlatform(_ context.Context, _ string) ([]service.Account, error) {
+ return r.accounts, nil
+}
+func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatform(context.Context, int64, string) ([]service.Account, error) {
+ return r.accounts, nil
+}
+func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context, []string) ([]service.Account, error) {
+ return r.accounts, nil
+}
+func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) {
+ return r.accounts, nil
+}
+func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error {
+ return nil
+}
+func (r *stubAccountRepoForHandler) SetModelRateLimit(context.Context, int64, string, time.Time) error {
+ return nil
+}
+func (r *stubAccountRepoForHandler) SetOverloaded(context.Context, int64, time.Time) error {
+ return nil
+}
+func (r *stubAccountRepoForHandler) SetTempUnschedulable(context.Context, int64, time.Time, string) error {
+ return nil
+}
+func (r *stubAccountRepoForHandler) ClearTempUnschedulable(context.Context, int64) error { return nil }
+func (r *stubAccountRepoForHandler) ClearRateLimit(context.Context, int64) error { return nil }
+func (r *stubAccountRepoForHandler) ClearAntigravityQuotaScopes(context.Context, int64) error {
+ return nil
+}
+func (r *stubAccountRepoForHandler) ClearModelRateLimits(context.Context, int64) error { return nil }
+func (r *stubAccountRepoForHandler) UpdateSessionWindow(context.Context, int64, *time.Time, *time.Time, string) error {
+ return nil
+}
+func (r *stubAccountRepoForHandler) UpdateExtra(context.Context, int64, map[string]any) error {
+ return nil
+}
+func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service.AccountBulkUpdate) (int64, error) {
+ return 0, nil
+}
+
+// ==================== Stub: SoraClient (用于 SoraGatewayService) ====================
+
+var _ service.SoraClient = (*stubSoraClientForHandler)(nil)
+
+type stubSoraClientForHandler struct {
+ videoStatus *service.SoraVideoTaskStatus
+}
+
+func (s *stubSoraClientForHandler) Enabled() bool { return true }
+func (s *stubSoraClientForHandler) UploadImage(context.Context, *service.Account, []byte, string) (string, error) {
+ return "", nil
+}
+func (s *stubSoraClientForHandler) CreateImageTask(context.Context, *service.Account, service.SoraImageRequest) (string, error) {
+ return "task-image", nil
+}
+func (s *stubSoraClientForHandler) CreateVideoTask(context.Context, *service.Account, service.SoraVideoRequest) (string, error) {
+ return "task-video", nil
+}
+func (s *stubSoraClientForHandler) CreateStoryboardTask(context.Context, *service.Account, service.SoraStoryboardRequest) (string, error) {
+ return "task-video", nil
+}
+func (s *stubSoraClientForHandler) UploadCharacterVideo(context.Context, *service.Account, []byte) (string, error) {
+ return "", nil
+}
+func (s *stubSoraClientForHandler) GetCameoStatus(context.Context, *service.Account, string) (*service.SoraCameoStatus, error) {
+ return nil, nil
+}
+func (s *stubSoraClientForHandler) DownloadCharacterImage(context.Context, *service.Account, string) ([]byte, error) {
+ return nil, nil
+}
+func (s *stubSoraClientForHandler) UploadCharacterImage(context.Context, *service.Account, []byte) (string, error) {
+ return "", nil
+}
+func (s *stubSoraClientForHandler) FinalizeCharacter(context.Context, *service.Account, service.SoraCharacterFinalizeRequest) (string, error) {
+ return "", nil
+}
+func (s *stubSoraClientForHandler) SetCharacterPublic(context.Context, *service.Account, string) error {
+ return nil
+}
+func (s *stubSoraClientForHandler) DeleteCharacter(context.Context, *service.Account, string) error {
+ return nil
+}
+func (s *stubSoraClientForHandler) PostVideoForWatermarkFree(context.Context, *service.Account, string) (string, error) {
+ return "", nil
+}
+func (s *stubSoraClientForHandler) DeletePost(context.Context, *service.Account, string) error {
+ return nil
+}
+func (s *stubSoraClientForHandler) GetWatermarkFreeURLCustom(context.Context, *service.Account, string, string, string) (string, error) {
+ return "", nil
+}
+func (s *stubSoraClientForHandler) EnhancePrompt(context.Context, *service.Account, string, string, int) (string, error) {
+ return "", nil
+}
+func (s *stubSoraClientForHandler) GetImageTask(context.Context, *service.Account, string) (*service.SoraImageTaskStatus, error) {
+ return nil, nil
+}
+func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Account, _ string) (*service.SoraVideoTaskStatus, error) {
+ return s.videoStatus, nil
+}
+
+// ==================== 辅助:创建最小 GatewayService 和 SoraGatewayService ====================
+
+// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
+func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
+ return service.NewGatewayService(
+ accountRepo, nil, nil, nil, nil, nil, nil, nil,
+ nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
+ )
+}
+
+// newMinimalSoraGatewayService 创建最小 SoraGatewayService(用于测试 Forward)。
+func newMinimalSoraGatewayService(soraClient service.SoraClient) *service.SoraGatewayService {
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ PollIntervalSeconds: 1,
+ MaxPollAttempts: 1,
+ },
+ },
+ }
+ return service.NewSoraGatewayService(soraClient, nil, nil, cfg)
+}
+
+// ==================== processGeneration: 更多路径测试 ====================
+
+func TestProcessGeneration_SelectAccountError(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ // accountRepo 返回空列表 → SelectAccountForModel 返回 "no available accounts"
+ accountRepo := &stubAccountRepoForHandler{accounts: nil}
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{genService: genService, gatewayService: gatewayService}
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
+ require.Equal(t, "failed", repo.gens[1].Status)
+ require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败")
+}
+
+func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) {
+ t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ // 提供可用账号使 SelectAccountForModel 成功
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ // soraGatewayService 为 nil
+ h := &SoraClientHandler{genService: genService, gatewayService: gatewayService}
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
+ require.Equal(t, "failed", repo.gens[1].Status)
+ require.Contains(t, repo.gens[1].ErrorMessage, "soraGatewayService")
+}
+
+func TestProcessGeneration_ForwardError(t *testing.T) {
+ t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ // SoraClient 返回视频任务失败
+ soraClient := &stubSoraClientForHandler{
+ videoStatus: &service.SoraVideoTaskStatus{
+ Status: "failed",
+ ErrorMsg: "content policy violation",
+ },
+ }
+ soraGatewayService := newMinimalSoraGatewayService(soraClient)
+ h := &SoraClientHandler{
+ genService: genService,
+ gatewayService: gatewayService,
+ soraGatewayService: soraGatewayService,
+ }
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
+ require.Equal(t, "failed", repo.gens[1].Status)
+ require.Contains(t, repo.gens[1].ErrorMessage, "生成失败")
+}
+
+func TestProcessGeneration_ForwardErrorCancelled(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ // MarkGenerating 内部调用 GetByID(第 1 次),Forward 失败后 processGeneration
+ // 调用 GetByID(第 2 次)。模拟外部在 Forward 期间取消了任务。
+ repo.getByIDOverrideAfterN = 1
+ repo.getByIDOverrideStatus = "cancelled"
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ soraClient := &stubSoraClientForHandler{
+ videoStatus: &service.SoraVideoTaskStatus{Status: "failed", ErrorMsg: "reject"},
+ }
+ soraGatewayService := newMinimalSoraGatewayService(soraClient)
+ h := &SoraClientHandler{
+ genService: genService,
+ gatewayService: gatewayService,
+ soraGatewayService: soraGatewayService,
+ }
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
+ // Forward 失败后检测到外部取消,不应调用 MarkFailed(状态保持 generating)
+ require.Equal(t, "generating", repo.gens[1].Status)
+}
+
+func TestProcessGeneration_ForwardSuccessNoMediaURL(t *testing.T) {
+ t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ // SoraClient 返回 completed 但无 URL
+ soraClient := &stubSoraClientForHandler{
+ videoStatus: &service.SoraVideoTaskStatus{
+ Status: "completed",
+ URLs: nil, // 无 URL
+ },
+ }
+ soraGatewayService := newMinimalSoraGatewayService(soraClient)
+ h := &SoraClientHandler{
+ genService: genService,
+ gatewayService: gatewayService,
+ soraGatewayService: soraGatewayService,
+ }
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
+ require.Equal(t, "failed", repo.gens[1].Status)
+ require.Contains(t, repo.gens[1].ErrorMessage, "未获取到媒体 URL")
+}
+
+func TestProcessGeneration_ForwardSuccessCancelledBeforeStore(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ // MarkGenerating 调用 GetByID(第 1 次),之后 processGeneration 行 176 调用 GetByID(第 2 次)
+ // 第 2 次返回 "cancelled" 状态,模拟外部取消
+ repo.getByIDOverrideAfterN = 1
+ repo.getByIDOverrideStatus = "cancelled"
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ soraClient := &stubSoraClientForHandler{
+ videoStatus: &service.SoraVideoTaskStatus{
+ Status: "completed",
+ URLs: []string{"https://example.com/video.mp4"},
+ },
+ }
+ soraGatewayService := newMinimalSoraGatewayService(soraClient)
+ h := &SoraClientHandler{
+ genService: genService,
+ gatewayService: gatewayService,
+ soraGatewayService: soraGatewayService,
+ }
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
+ // Forward 成功后检测到外部取消,不应调用存储和 MarkCompleted(状态保持 generating)
+ require.Equal(t, "generating", repo.gens[1].Status)
+}
+
+func TestProcessGeneration_FullSuccessUpstream(t *testing.T) {
+ t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ soraClient := &stubSoraClientForHandler{
+ videoStatus: &service.SoraVideoTaskStatus{
+ Status: "completed",
+ URLs: []string{"https://example.com/video.mp4"},
+ },
+ }
+ soraGatewayService := newMinimalSoraGatewayService(soraClient)
+ // 无 S3 和本地存储,降级到 upstream
+ h := &SoraClientHandler{
+ genService: genService,
+ gatewayService: gatewayService,
+ soraGatewayService: soraGatewayService,
+ }
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
+ require.Equal(t, "completed", repo.gens[1].Status)
+ require.Equal(t, service.SoraStorageTypeUpstream, repo.gens[1].StorageType)
+ require.NotEmpty(t, repo.gens[1].MediaURL)
+}
+
+func TestProcessGeneration_FullSuccessWithS3(t *testing.T) {
+ t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ soraClient := &stubSoraClientForHandler{
+ videoStatus: &service.SoraVideoTaskStatus{
+ Status: "completed",
+ URLs: []string{sourceServer.URL + "/video.mp4"},
+ },
+ }
+ soraGatewayService := newMinimalSoraGatewayService(soraClient)
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+
+ userRepo := newStubUserRepoForHandler()
+ userRepo.users[1] = &service.User{
+ ID: 1, SoraStorageQuotaBytes: 100 * 1024 * 1024,
+ }
+ quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
+
+ h := &SoraClientHandler{
+ genService: genService,
+ gatewayService: gatewayService,
+ soraGatewayService: soraGatewayService,
+ s3Storage: s3Storage,
+ quotaService: quotaService,
+ }
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
+ require.Equal(t, "completed", repo.gens[1].Status)
+ require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
+ require.NotEmpty(t, repo.gens[1].S3ObjectKeys)
+ require.Greater(t, repo.gens[1].FileSizeBytes, int64(0))
+ // 验证配额已累加
+ require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0))
+}
+
+func TestProcessGeneration_MarkCompletedFails(t *testing.T) {
+ t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ // 第 1 次 Update(MarkGenerating)成功,第 2 次(MarkCompleted)失败
+ repo.updateCallCount = new(int32)
+ repo.updateFailAfterN = 1
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ soraClient := &stubSoraClientForHandler{
+ videoStatus: &service.SoraVideoTaskStatus{
+ Status: "completed",
+ URLs: []string{"https://example.com/video.mp4"},
+ },
+ }
+ soraGatewayService := newMinimalSoraGatewayService(soraClient)
+ h := &SoraClientHandler{
+ genService: genService,
+ gatewayService: gatewayService,
+ soraGatewayService: soraGatewayService,
+ }
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
+ // MarkCompleted 内部先修改内存对象状态为 completed,然后 Update 失败。
+ // 由于 stub 存储的是指针,内存中的状态已被修改为 completed。
+ // 此测试验证 processGeneration 在 MarkCompleted 失败后提前返回(不调用 AddUsage)。
+ require.Equal(t, "completed", repo.gens[1].Status)
+}
+
+// ==================== cleanupStoredMedia 直接测试 ====================
+
+func TestCleanupStoredMedia_S3Path(t *testing.T) {
+ // S3 清理路径:s3Storage 为 nil 时不 panic
+ h := &SoraClientHandler{}
+ // 不应 panic
+ h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
+}
+
+func TestCleanupStoredMedia_LocalPath(t *testing.T) {
+ // 本地清理路径:mediaStorage 为 nil 时不 panic
+ h := &SoraClientHandler{}
+ h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"/tmp/test.mp4"})
+}
+
+func TestCleanupStoredMedia_UpstreamPath(t *testing.T) {
+ // upstream 类型不清理
+ h := &SoraClientHandler{}
+ h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeUpstream, nil, nil)
+}
+
+func TestCleanupStoredMedia_EmptyKeys(t *testing.T) {
+ // 空 keys 不触发清理
+ h := &SoraClientHandler{}
+ h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, nil, nil)
+ h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, nil)
+}
+
+// ==================== DeleteGeneration: 本地存储清理路径 ====================
+
+func TestDeleteGeneration_LocalStorageCleanup(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "sora-delete-test-*")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Storage: config.SoraStorageConfig{
+ Type: "local",
+ LocalPath: tmpDir,
+ },
+ },
+ }
+ mediaStorage := service.NewSoraMediaStorage(cfg)
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1,
+ UserID: 1,
+ Status: "completed",
+ StorageType: service.SoraStorageTypeLocal,
+ MediaURL: "video/test.mp4",
+ MediaURLs: []string{"video/test.mp4"},
+ }
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
+
+ c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.DeleteGeneration(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+ _, exists := repo.gens[1]
+ require.False(t, exists)
+}
+
+func TestDeleteGeneration_LocalStorageCleanup_MediaURLFallback(t *testing.T) {
+ // MediaURLs 为空,使用 MediaURL 作为清理路径
+ tmpDir, err := os.MkdirTemp("", "sora-delete-fallback-*")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Storage: config.SoraStorageConfig{
+ Type: "local",
+ LocalPath: tmpDir,
+ },
+ },
+ }
+ mediaStorage := service.NewSoraMediaStorage(cfg)
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1,
+ UserID: 1,
+ Status: "completed",
+ StorageType: service.SoraStorageTypeLocal,
+ MediaURL: "video/test.mp4",
+ MediaURLs: nil, // 空
+ }
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
+
+ c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.DeleteGeneration(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+}
+
+func TestDeleteGeneration_NonLocalStorage_SkipCleanup(t *testing.T) {
+ // 非本地存储类型 → 跳过清理
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1,
+ UserID: 1,
+ Status: "completed",
+ StorageType: service.SoraStorageTypeUpstream,
+ MediaURL: "https://upstream.com/v.mp4",
+ }
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService}
+
+ c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.DeleteGeneration(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+}
+
+func TestDeleteGeneration_DeleteError(t *testing.T) {
+ // repo.Delete 出错
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream"}
+ repo.deleteErr = fmt.Errorf("delete failed")
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService}
+
+ c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.DeleteGeneration(c)
+ require.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+// ==================== fetchUpstreamModels 测试 ====================
+
+func TestFetchUpstreamModels_NilGateway(t *testing.T) {
+ t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
+ h := &SoraClientHandler{}
+ _, err := h.fetchUpstreamModels(context.Background())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "gatewayService 未初始化")
+}
+
+func TestFetchUpstreamModels_NoAccounts(t *testing.T) {
+ t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
+ accountRepo := &stubAccountRepoForHandler{accounts: nil}
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{gatewayService: gatewayService}
+ _, err := h.fetchUpstreamModels(context.Background())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "选择 Sora 账号失败")
+}
+
+func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) {
+ t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Type: "oauth", Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{gatewayService: gatewayService}
+ _, err := h.fetchUpstreamModels(context.Background())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "不支持模型同步")
+}
+
+func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) {
+ t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
+ Credentials: map[string]any{"base_url": "https://sora.test"}},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{gatewayService: gatewayService}
+ _, err := h.fetchUpstreamModels(context.Background())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "api_key")
+}
+
+func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) {
+ t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
+ // GetBaseURL() 在缺少 base_url 时返回默认值 "https://api.anthropic.com"
+ // 因此不会触发 "账号缺少 base_url" 错误,而是会尝试请求默认 URL 并失败
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
+ Credentials: map[string]any{"api_key": "sk-test"}},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{gatewayService: gatewayService}
+ _, err := h.fetchUpstreamModels(context.Background())
+ require.Error(t, err)
+}
+
+func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) {
+ t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ }))
+ defer ts.Close()
+
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
+ Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{gatewayService: gatewayService}
+ _, err := h.fetchUpstreamModels(context.Background())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "状态码 500")
+}
+
+func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) {
+ t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte("not json"))
+ }))
+ defer ts.Close()
+
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
+ Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{gatewayService: gatewayService}
+ _, err := h.fetchUpstreamModels(context.Background())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "解析响应失败")
+}
+
+func TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) {
+ t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte(`{"data":[]}`))
+ }))
+ defer ts.Close()
+
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
+ Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{gatewayService: gatewayService}
+ _, err := h.fetchUpstreamModels(context.Background())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "空模型列表")
+}
+
+func TestFetchUpstreamModels_Success(t *testing.T) {
+ t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // 验证请求头
+ require.Equal(t, "Bearer sk-test", r.Header.Get("Authorization"))
+ require.True(t, strings.HasSuffix(r.URL.Path, "/sora/v1/models"))
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"sora2-portrait-10s"},{"id":"sora2-landscape-15s"},{"id":"gpt-image"}]}`))
+ }))
+ defer ts.Close()
+
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
+ Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{gatewayService: gatewayService}
+ families, err := h.fetchUpstreamModels(context.Background())
+ require.NoError(t, err)
+ require.NotEmpty(t, families)
+}
+
+func TestFetchUpstreamModels_UnrecognizedModels(t *testing.T) {
+ t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte(`{"data":[{"id":"unknown-model-1"},{"id":"unknown-model-2"}]}`))
+ }))
+ defer ts.Close()
+
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
+ Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{gatewayService: gatewayService}
+ _, err := h.fetchUpstreamModels(context.Background())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "未能从上游模型列表中识别")
+}
+
+// ==================== getModelFamilies 缓存测试 ====================
+
+func TestGetModelFamilies_CachesLocalConfig(t *testing.T) {
+ // gatewayService 为 nil → fetchUpstreamModels 失败 → 降级到本地配置
+ h := &SoraClientHandler{}
+ families := h.getModelFamilies(context.Background())
+ require.NotEmpty(t, families)
+
+ // 第二次调用应命中缓存(modelCacheUpstream=false → 使用短 TTL)
+ families2 := h.getModelFamilies(context.Background())
+ require.Equal(t, families, families2)
+ require.False(t, h.modelCacheUpstream)
+}
+
+func TestGetModelFamilies_CachesUpstreamResult(t *testing.T) {
+ t.Skip("TODO: 临时屏蔽依赖 Sora 上游模型同步的缓存测试,待账号选择逻辑稳定后恢复")
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"gpt-image"}]}`))
+ }))
+ defer ts.Close()
+
+ accountRepo := &stubAccountRepoForHandler{
+ accounts: []service.Account{
+ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
+ Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
+ },
+ }
+ gatewayService := newMinimalGatewayService(accountRepo)
+ h := &SoraClientHandler{gatewayService: gatewayService}
+
+ families := h.getModelFamilies(context.Background())
+ require.NotEmpty(t, families)
+ require.True(t, h.modelCacheUpstream)
+
+ // 第二次调用命中缓存
+ families2 := h.getModelFamilies(context.Background())
+ require.Equal(t, families, families2)
+}
+
+func TestGetModelFamilies_ExpiredCacheRefreshes(t *testing.T) {
+ // 预设过期的缓存(modelCacheUpstream=false → 短 TTL)
+ h := &SoraClientHandler{
+ cachedFamilies: []service.SoraModelFamily{{ID: "old"}},
+ modelCacheTime: time.Now().Add(-10 * time.Minute), // 已过期
+ modelCacheUpstream: false,
+ }
+ // gatewayService 为 nil → fetchUpstreamModels 失败 → 使用本地配置刷新缓存
+ families := h.getModelFamilies(context.Background())
+ require.NotEmpty(t, families)
+ // 缓存已刷新,不再是 "old"
+ found := false
+ for _, f := range families {
+ if f.ID == "old" {
+ found = true
+ }
+ }
+ require.False(t, found, "过期缓存应被刷新")
+}
+
+// ==================== processGeneration: groupID 与 ForcePlatform ====================
+
+func TestProcessGeneration_NilGroupID_WithGateway_SelectAccountFails(t *testing.T) {
+ // groupID 为 nil → 设置 ForcePlatform=sora → 无可用 sora 账号 → MarkFailed
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ // 空账号列表 → SelectAccountForModel 失败
+ accountRepo := &stubAccountRepoForHandler{accounts: nil}
+ gatewayService := newMinimalGatewayService(accountRepo)
+
+ h := &SoraClientHandler{
+ genService: genService,
+ gatewayService: gatewayService,
+ }
+
+ h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
+ require.Equal(t, "failed", repo.gens[1].Status)
+ require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败")
+}
+
+// ==================== Generate: 配额检查非 QuotaExceeded 错误 ====================
+
+func TestGenerate_CheckQuotaNonQuotaError(t *testing.T) {
+ // quotaService.CheckQuota 返回非 QuotaExceededError → 返回 403
+ repo := newStubSoraGenRepo()
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ // 用户不存在 → GetByID 失败 → CheckQuota 返回普通 error
+ userRepo := newStubUserRepoForHandler()
+ quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
+
+ h := NewSoraClientHandler(genService, quotaService, nil, nil, nil, nil, nil)
+
+ body := `{"model":"sora2-landscape-10s","prompt":"test"}`
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusForbidden, rec.Code)
+}
+
+// ==================== Generate: CreatePending 并发限制错误 ====================
+
+// stubSoraGenRepoWithAtomicCreate 实现 soraGenerationRepoAtomicCreator 接口
+type stubSoraGenRepoWithAtomicCreate struct {
+ stubSoraGenRepo
+ limitErr error
+}
+
+func (r *stubSoraGenRepoWithAtomicCreate) CreatePendingWithLimit(_ context.Context, gen *service.SoraGeneration, _ []string, _ int64) error {
+ if r.limitErr != nil {
+ return r.limitErr
+ }
+ return r.stubSoraGenRepo.Create(context.Background(), gen)
+}
+
+func TestGenerate_CreatePendingConcurrencyLimit(t *testing.T) {
+ repo := &stubSoraGenRepoWithAtomicCreate{
+ stubSoraGenRepo: *newStubSoraGenRepo(),
+ limitErr: service.ErrSoraGenerationConcurrencyLimit,
+ }
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := NewSoraClientHandler(genService, nil, nil, nil, nil, nil, nil)
+
+ body := `{"model":"sora2-landscape-10s","prompt":"test"}`
+ c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1)
+ h.Generate(c)
+ require.Equal(t, http.StatusTooManyRequests, rec.Code)
+ resp := parseResponse(t, rec)
+ require.Contains(t, resp["message"], "3")
+}
+
+// ==================== SaveToStorage: 配额超限 ====================
+
+func TestSaveToStorage_QuotaExceeded(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: sourceServer.URL + "/v.mp4",
+ }
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ // 用户配额已满
+ userRepo := newStubUserRepoForHandler()
+ userRepo.users[1] = &service.User{
+ ID: 1,
+ SoraStorageQuotaBytes: 10,
+ SoraStorageUsedBytes: 10,
+ }
+ quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusTooManyRequests, rec.Code)
+}
+
+// ==================== SaveToStorage: 配额非 QuotaExceeded 错误 ====================
+
+func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: sourceServer.URL + "/v.mp4",
+ }
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ // 用户不存在 → GetByID 失败 → AddUsage 返回普通 error
+ userRepo := newStubUserRepoForHandler()
+ quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+}
+
+// ==================== SaveToStorage: MediaURLs 全为空 ====================
+
+func TestSaveToStorage_EmptyMediaURLs(t *testing.T) {
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: "",
+ MediaURLs: []string{},
+ }
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ resp := parseResponse(t, rec)
+ require.Contains(t, resp["message"], "已过期")
+}
+
+// ==================== SaveToStorage: S3 上传失败时已有已上传文件需清理 ====================
+
+func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("fail-second")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: sourceServer.URL + "/v1.mp4",
+ MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"},
+ }
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+}
+
+// ==================== SaveToStorage: UpdateStorageForCompleted 失败(含配额回滚) ====================
+
+func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) {
+ sourceServer := newFakeSourceServer()
+ defer sourceServer.Close()
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: "upstream",
+ MediaURL: sourceServer.URL + "/v.mp4",
+ }
+ repo.updateErr = fmt.Errorf("db error")
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+
+ userRepo := newStubUserRepoForHandler()
+ userRepo.users[1] = &service.User{
+ ID: 1,
+ SoraStorageQuotaBytes: 100 * 1024 * 1024,
+ SoraStorageUsedBytes: 0,
+ }
+ quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
+ h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.SaveToStorage(c)
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+}
+
+// ==================== cleanupStoredMedia: 实际 S3 删除路径 ====================
+
+func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) {
+ fakeS3 := newFakeS3Server("ok")
+ defer fakeS3.Close()
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ h := &SoraClientHandler{s3Storage: s3Storage}
+
+ h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil)
+}
+
+func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) {
+ fakeS3 := newFakeS3Server("fail")
+ defer fakeS3.Close()
+ s3Storage := newS3StorageForHandler(fakeS3.URL)
+ h := &SoraClientHandler{s3Storage: s3Storage}
+
+ h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
+}
+
+func TestCleanupStoredMedia_LocalDeleteFails_LogOnly(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "sora-cleanup-fail-*")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Storage: config.SoraStorageConfig{
+ Type: "local",
+ LocalPath: tmpDir,
+ },
+ },
+ }
+ mediaStorage := service.NewSoraMediaStorage(cfg)
+ h := &SoraClientHandler{mediaStorage: mediaStorage}
+
+ h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"nonexistent/file.mp4"})
+}
+
+// ==================== DeleteGeneration: 本地文件删除失败(仅日志) ====================
+
+func TestDeleteGeneration_LocalStorageDeleteFails_LogOnly(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "sora-del-test-*")
+ require.NoError(t, err)
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Storage: config.SoraStorageConfig{
+ Type: "local",
+ LocalPath: tmpDir,
+ },
+ },
+ }
+ mediaStorage := service.NewSoraMediaStorage(cfg)
+
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{
+ ID: 1, UserID: 1, Status: "completed",
+ StorageType: service.SoraStorageTypeLocal,
+ MediaURL: "nonexistent/video.mp4",
+ MediaURLs: []string{"nonexistent/video.mp4"},
+ }
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
+
+ c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.DeleteGeneration(c)
+ require.Equal(t, http.StatusOK, rec.Code)
+}
+
+// ==================== CancelGeneration: 任务已结束冲突 ====================
+
+func TestCancelGeneration_AlreadyCompleted(t *testing.T) {
+ repo := newStubSoraGenRepo()
+ repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
+ genService := service.NewSoraGenerationService(repo, nil, nil)
+ h := &SoraClientHandler{genService: genService}
+
+ c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
+ c.Params = gin.Params{{Key: "id", Value: "1"}}
+ h.CancelGeneration(c)
+ require.Equal(t, http.StatusConflict, rec.Code)
+}
diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go
index ab3a3f14..48c1e451 100644
--- a/backend/internal/handler/sora_gateway_handler.go
+++ b/backend/internal/handler/sora_gateway_handler.go
@@ -7,7 +7,6 @@ import (
"encoding/json"
"errors"
"fmt"
- "io"
"net/http"
"os"
"path"
@@ -17,6 +16,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
+ pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -107,7 +107,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
zap.Any("group_id", apiKey.GroupID),
)
- body, err := io.ReadAll(c.Request.Body)
+ body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
@@ -461,6 +461,14 @@ func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask)
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
+ defer func() {
+ if recovered := recover(); recovered != nil {
+ logger.L().With(
+ zap.String("component", "handler.sora_gateway.chat_completions"),
+ zap.Any("panic", recovered),
+ ).Error("sora.usage_record_task_panic_recovered")
+ }
+ }()
task(ctx)
}
diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go
index cc792350..355cdb7a 100644
--- a/backend/internal/handler/sora_gateway_handler_test.go
+++ b/backend/internal/handler/sora_gateway_handler_test.go
@@ -314,10 +314,13 @@ func (s *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID i
func (s *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
return nil, nil
}
-func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
+func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
return nil, nil
}
-func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
+func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
+ return nil, nil
+}
+func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
return nil, nil
}
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
@@ -426,7 +429,8 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
deferredService,
nil,
testutil.StubSessionLimitCache{},
- nil,
+ nil, // rpmCache
+ nil, // digestStore
)
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
diff --git a/backend/internal/handler/usage_handler.go b/backend/internal/handler/usage_handler.go
index b8182dad..2bd0e0d7 100644
--- a/backend/internal/handler/usage_handler.go
+++ b/backend/internal/handler/usage_handler.go
@@ -2,6 +2,7 @@ package handler
import (
"strconv"
+ "strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -65,8 +66,17 @@ func (h *UsageHandler) List(c *gin.Context) {
// Parse additional filters
model := c.Query("model")
+ var requestType *int16
var stream *bool
- if streamStr := c.Query("stream"); streamStr != "" {
+ if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
+ parsed, err := service.ParseUsageRequestType(requestTypeStr)
+ if err != nil {
+ response.BadRequest(c, err.Error())
+ return
+ }
+ value := int16(parsed)
+ requestType = &value
+ } else if streamStr := c.Query("stream"); streamStr != "" {
val, err := strconv.ParseBool(streamStr)
if err != nil {
response.BadRequest(c, "Invalid stream value, use true or false")
@@ -114,6 +124,7 @@ func (h *UsageHandler) List(c *gin.Context) {
UserID: subject.UserID, // Always filter by current user for security
APIKeyID: apiKeyID,
Model: model,
+ RequestType: requestType,
Stream: stream,
BillingType: billingType,
StartTime: startTime,
diff --git a/backend/internal/handler/usage_handler_request_type_test.go b/backend/internal/handler/usage_handler_request_type_test.go
new file mode 100644
index 00000000..7c4c7913
--- /dev/null
+++ b/backend/internal/handler/usage_handler_request_type_test.go
@@ -0,0 +1,80 @@
+package handler
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type userUsageRepoCapture struct {
+ service.UsageLogRepository
+ listFilters usagestats.UsageLogFilters
+}
+
+func (s *userUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
+ s.listFilters = filters
+ return []service.UsageLog{}, &pagination.PaginationResult{
+ Total: 0,
+ Page: params.Page,
+ PageSize: params.PageSize,
+ Pages: 0,
+ }, nil
+}
+
+func newUserUsageRequestTypeTestRouter(repo *userUsageRepoCapture) *gin.Engine {
+ gin.SetMode(gin.TestMode)
+ usageSvc := service.NewUsageService(repo, nil, nil, nil)
+ handler := NewUsageHandler(usageSvc, nil)
+ router := gin.New()
+ router.Use(func(c *gin.Context) {
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42})
+ c.Next()
+ })
+ router.GET("/usage", handler.List)
+ return router
+}
+
+func TestUserUsageListRequestTypePriority(t *testing.T) {
+ repo := &userUsageRepoCapture{}
+ router := newUserUsageRequestTypeTestRouter(repo)
+
+ req := httptest.NewRequest(http.MethodGet, "/usage?request_type=ws_v2&stream=bad", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, int64(42), repo.listFilters.UserID)
+ require.NotNil(t, repo.listFilters.RequestType)
+ require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType)
+ require.Nil(t, repo.listFilters.Stream)
+}
+
+func TestUserUsageListInvalidRequestType(t *testing.T) {
+ repo := &userUsageRepoCapture{}
+ router := newUserUsageRequestTypeTestRouter(repo)
+
+ req := httptest.NewRequest(http.MethodGet, "/usage?request_type=invalid", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestUserUsageListInvalidStream(t *testing.T) {
+ repo := &userUsageRepoCapture{}
+ router := newUserUsageRequestTypeTestRouter(repo)
+
+ req := httptest.NewRequest(http.MethodGet, "/usage?stream=invalid", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+}
diff --git a/backend/internal/handler/usage_record_submit_task_test.go b/backend/internal/handler/usage_record_submit_task_test.go
index df759f44..c7c48e14 100644
--- a/backend/internal/handler/usage_record_submit_task_test.go
+++ b/backend/internal/handler/usage_record_submit_task_test.go
@@ -61,6 +61,22 @@ func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
})
}
+func TestGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
+ h := &GatewayHandler{}
+ var called atomic.Bool
+
+ require.NotPanics(t, func() {
+ h.submitUsageRecordTask(func(ctx context.Context) {
+ panic("usage task panic")
+ })
+ })
+
+ h.submitUsageRecordTask(func(ctx context.Context) {
+ called.Store(true)
+ })
+ require.True(t, called.Load(), "panic 后后续任务应仍可执行")
+}
+
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
pool := newUsageRecordTestPool(t)
h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
@@ -98,6 +114,22 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
})
}
+func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
+ h := &OpenAIGatewayHandler{}
+ var called atomic.Bool
+
+ require.NotPanics(t, func() {
+ h.submitUsageRecordTask(func(ctx context.Context) {
+ panic("usage task panic")
+ })
+ })
+
+ h.submitUsageRecordTask(func(ctx context.Context) {
+ called.Store(true)
+ })
+ require.True(t, called.Load(), "panic 后后续任务应仍可执行")
+}
+
func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
pool := newUsageRecordTestPool(t)
h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
@@ -134,3 +166,19 @@ func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h.submitUsageRecordTask(nil)
})
}
+
+func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
+ h := &SoraGatewayHandler{}
+ var called atomic.Bool
+
+ require.NotPanics(t, func() {
+ h.submitUsageRecordTask(func(ctx context.Context) {
+ panic("usage task panic")
+ })
+ })
+
+ h.submitUsageRecordTask(func(ctx context.Context) {
+ called.Store(true)
+ })
+ require.True(t, called.Load(), "panic 后后续任务应仍可执行")
+}
diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go
index 79d583fd..76f5a979 100644
--- a/backend/internal/handler/wire.go
+++ b/backend/internal/handler/wire.go
@@ -14,6 +14,7 @@ func ProvideAdminHandlers(
groupHandler *admin.GroupHandler,
accountHandler *admin.AccountHandler,
announcementHandler *admin.AnnouncementHandler,
+ dataManagementHandler *admin.DataManagementHandler,
oauthHandler *admin.OAuthHandler,
openaiOAuthHandler *admin.OpenAIOAuthHandler,
geminiOAuthHandler *admin.GeminiOAuthHandler,
@@ -28,6 +29,7 @@ func ProvideAdminHandlers(
usageHandler *admin.UsageHandler,
userAttributeHandler *admin.UserAttributeHandler,
errorPassthroughHandler *admin.ErrorPassthroughHandler,
+ apiKeyHandler *admin.AdminAPIKeyHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
@@ -35,6 +37,7 @@ func ProvideAdminHandlers(
Group: groupHandler,
Account: accountHandler,
Announcement: announcementHandler,
+ DataManagement: dataManagementHandler,
OAuth: oauthHandler,
OpenAIOAuth: openaiOAuthHandler,
GeminiOAuth: geminiOAuthHandler,
@@ -49,6 +52,7 @@ func ProvideAdminHandlers(
Usage: usageHandler,
UserAttribute: userAttributeHandler,
ErrorPassthrough: errorPassthroughHandler,
+ APIKey: apiKeyHandler,
}
}
@@ -75,6 +79,7 @@ func ProvideHandlers(
gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler,
soraGatewayHandler *SoraGatewayHandler,
+ soraClientHandler *SoraClientHandler,
settingHandler *SettingHandler,
totpHandler *TotpHandler,
_ *service.IdempotencyCoordinator,
@@ -92,6 +97,7 @@ func ProvideHandlers(
Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler,
SoraGateway: soraGatewayHandler,
+ SoraClient: soraClientHandler,
Setting: settingHandler,
Totp: totpHandler,
}
@@ -119,6 +125,7 @@ var ProviderSet = wire.NewSet(
admin.NewGroupHandler,
admin.NewAccountHandler,
admin.NewAnnouncementHandler,
+ admin.NewDataManagementHandler,
admin.NewOAuthHandler,
admin.NewOpenAIOAuthHandler,
admin.NewGeminiOAuthHandler,
@@ -133,6 +140,7 @@ var ProviderSet = wire.NewSet(
admin.NewUsageHandler,
admin.NewUserAttributeHandler,
admin.NewErrorPassthroughHandler,
+ admin.NewAdminAPIKeyHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,
diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go
index 7c127b90..7cc68060 100644
--- a/backend/internal/pkg/antigravity/claude_types.go
+++ b/backend/internal/pkg/antigravity/claude_types.go
@@ -151,6 +151,9 @@ var claudeModels = []modelDef{
{ID: "claude-opus-4-5-thinking", DisplayName: "Claude Opus 4.5 Thinking", CreatedAt: "2025-11-01T00:00:00Z"},
{ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"},
{ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"},
+ {ID: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-05T00:00:00Z"},
+ {ID: "claude-opus-4-6-thinking", DisplayName: "Claude Opus 4.6 Thinking", CreatedAt: "2026-02-05T00:00:00Z"},
+ {ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", CreatedAt: "2026-02-17T00:00:00Z"},
}
// Antigravity 支持的 Gemini 模型
@@ -161,6 +164,10 @@ var geminiModels = []modelDef{
{ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-low", DisplayName: "Gemini 3 Pro Low", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"},
+ {ID: "gemini-3.1-pro-low", DisplayName: "Gemini 3.1 Pro Low", CreatedAt: "2026-02-19T00:00:00Z"},
+ {ID: "gemini-3.1-pro-high", DisplayName: "Gemini 3.1 Pro High", CreatedAt: "2026-02-19T00:00:00Z"},
+ {ID: "gemini-3.1-flash-image", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: "2026-02-19T00:00:00Z"},
+ {ID: "gemini-3.1-flash-image-preview", DisplayName: "Gemini 3.1 Flash Image Preview", CreatedAt: "2026-02-19T00:00:00Z"},
{ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"},
}
diff --git a/backend/internal/pkg/antigravity/claude_types_test.go b/backend/internal/pkg/antigravity/claude_types_test.go
new file mode 100644
index 00000000..f7cb0a24
--- /dev/null
+++ b/backend/internal/pkg/antigravity/claude_types_test.go
@@ -0,0 +1,26 @@
+package antigravity
+
+import "testing"
+
+func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) {
+ t.Parallel()
+
+ models := DefaultModels()
+ byID := make(map[string]ClaudeModel, len(models))
+ for _, m := range models {
+ byID[m.ID] = m
+ }
+
+ requiredIDs := []string{
+ "claude-opus-4-6-thinking",
+ "gemini-3.1-flash-image",
+ "gemini-3.1-flash-image-preview",
+ "gemini-3-pro-image", // legacy compatibility
+ }
+
+ for _, id := range requiredIDs {
+ if _, ok := byID[id]; !ok {
+ t.Fatalf("expected model %q to be exposed in DefaultModels", id)
+ }
+ }
+}
diff --git a/backend/internal/pkg/antigravity/gemini_types.go b/backend/internal/pkg/antigravity/gemini_types.go
index 32495827..0ff24a1f 100644
--- a/backend/internal/pkg/antigravity/gemini_types.go
+++ b/backend/internal/pkg/antigravity/gemini_types.go
@@ -70,7 +70,7 @@ type GeminiGenerationConfig struct {
ImageConfig *GeminiImageConfig `json:"imageConfig,omitempty"`
}
-// GeminiImageConfig Gemini 图片生成配置(仅 gemini-3-pro-image 支持)
+// GeminiImageConfig Gemini 图片生成配置(gemini-3-pro-image / gemini-3.1-flash-image 等图片模型支持)
type GeminiImageConfig struct {
AspectRatio string `json:"aspectRatio,omitempty"` // "1:1", "16:9", "9:16", "4:3", "3:4"
ImageSize string `json:"imageSize,omitempty"` // "1K", "2K", "4K"
diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go
index 47c75142..18310655 100644
--- a/backend/internal/pkg/antigravity/oauth.go
+++ b/backend/internal/pkg/antigravity/oauth.go
@@ -49,11 +49,12 @@ const (
antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
)
-// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.18.4
-var defaultUserAgentVersion = "1.18.4"
+// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.19.6
+var defaultUserAgentVersion = "1.19.6"
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
-var defaultClientSecret = "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"
+// 默认值使用占位符,生产环境请通过环境变量注入真实值。
+var defaultClientSecret = "GOCSPX-your-client-secret"
func init() {
// 从环境变量读取版本号,未设置则使用默认值
diff --git a/backend/internal/pkg/antigravity/oauth_test.go b/backend/internal/pkg/antigravity/oauth_test.go
index 351708a5..2a2a52e9 100644
--- a/backend/internal/pkg/antigravity/oauth_test.go
+++ b/backend/internal/pkg/antigravity/oauth_test.go
@@ -612,14 +612,14 @@ func TestBuildAuthorizationURL_参数验证(t *testing.T) {
expectedParams := map[string]string{
"client_id": ClientID,
- "redirect_uri": RedirectURI,
- "response_type": "code",
- "scope": Scopes,
- "state": state,
- "code_challenge": codeChallenge,
- "code_challenge_method": "S256",
- "access_type": "offline",
- "prompt": "consent",
+ "redirect_uri": RedirectURI,
+ "response_type": "code",
+ "scope": Scopes,
+ "state": state,
+ "code_challenge": codeChallenge,
+ "code_challenge_method": "S256",
+ "access_type": "offline",
+ "prompt": "consent",
"include_granted_scopes": "true",
}
@@ -684,13 +684,13 @@ func TestConstants_值正确(t *testing.T) {
if err != nil {
t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err)
}
- if secret != "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" {
+ if secret != "GOCSPX-your-client-secret" {
t.Errorf("默认 client_secret 不匹配: got %s", secret)
}
if RedirectURI != "http://localhost:8085/callback" {
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
}
- if GetUserAgent() != "antigravity/1.18.4 windows/amd64" {
+ if GetUserAgent() != "antigravity/1.19.6 windows/amd64" {
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
}
if SessionTTL != 30*time.Minute {
diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go
index 423ad925..22405382 100644
--- a/backend/internal/pkg/claude/constants.go
+++ b/backend/internal/pkg/claude/constants.go
@@ -11,8 +11,13 @@ const (
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
BetaTokenCounting = "token-counting-2024-11-01"
BetaContext1M = "context-1m-2025-08-07"
+ BetaFastMode = "fast-mode-2026-02-01"
)
+// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
+// 这些 token 是客户端特有的,不应透传给上游 API。
+var DroppedBetas = []string{BetaContext1M, BetaFastMode}
+
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
diff --git a/backend/internal/pkg/ctxkey/ctxkey.go b/backend/internal/pkg/ctxkey/ctxkey.go
index b13d66cb..25782c55 100644
--- a/backend/internal/pkg/ctxkey/ctxkey.go
+++ b/backend/internal/pkg/ctxkey/ctxkey.go
@@ -52,4 +52,7 @@ const (
// PrefetchedStickyGroupID 标识上游预取 sticky session 时所使用的分组 ID。
// Service 层仅在分组匹配时复用 PrefetchedStickyAccountID,避免分组切换重试误用旧 sticky。
PrefetchedStickyGroupID Key = "ctx_prefetched_sticky_group_id"
+
+ // ClaudeCodeVersion stores the extracted Claude Code version from User-Agent (e.g. "2.1.22")
+ ClaudeCodeVersion Key = "ctx_claude_code_version"
)
diff --git a/backend/internal/pkg/errors/errors_test.go b/backend/internal/pkg/errors/errors_test.go
index 1a1c842e..25e62907 100644
--- a/backend/internal/pkg/errors/errors_test.go
+++ b/backend/internal/pkg/errors/errors_test.go
@@ -166,3 +166,18 @@ func TestToHTTP(t *testing.T) {
})
}
}
+
+func TestToHTTP_MetadataDeepCopy(t *testing.T) {
+ md := map[string]string{"k": "v"}
+ appErr := BadRequest("BAD_REQUEST", "invalid").WithMetadata(md)
+
+ code, body := ToHTTP(appErr)
+ require.Equal(t, http.StatusBadRequest, code)
+ require.Equal(t, "v", body.Metadata["k"])
+
+ md["k"] = "changed"
+ require.Equal(t, "v", body.Metadata["k"])
+
+ appErr.Metadata["k"] = "changed-again"
+ require.Equal(t, "v", body.Metadata["k"])
+}
diff --git a/backend/internal/pkg/errors/http.go b/backend/internal/pkg/errors/http.go
index 7b5560e3..420c69a3 100644
--- a/backend/internal/pkg/errors/http.go
+++ b/backend/internal/pkg/errors/http.go
@@ -16,6 +16,16 @@ func ToHTTP(err error) (statusCode int, body Status) {
return http.StatusOK, Status{Code: int32(http.StatusOK)}
}
- cloned := Clone(appErr)
- return int(cloned.Code), cloned.Status
+ body = Status{
+ Code: appErr.Code,
+ Reason: appErr.Reason,
+ Message: appErr.Message,
+ }
+ if appErr.Metadata != nil {
+ body.Metadata = make(map[string]string, len(appErr.Metadata))
+ for k, v := range appErr.Metadata {
+ body.Metadata[k] = v
+ }
+ }
+ return int(appErr.Code), body
}
diff --git a/backend/internal/pkg/geminicli/constants.go b/backend/internal/pkg/geminicli/constants.go
index 97234ffd..f5ee5735 100644
--- a/backend/internal/pkg/geminicli/constants.go
+++ b/backend/internal/pkg/geminicli/constants.go
@@ -39,7 +39,7 @@ const (
// They enable the "login without creating your own OAuth client" experience, but Google may
// restrict which scopes are allowed for this client.
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
- GeminiCLIOAuthClientSecret = "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"
+ GeminiCLIOAuthClientSecret = "GOCSPX-your-client-secret"
// GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret.
GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET"
diff --git a/backend/internal/pkg/httpclient/pool.go b/backend/internal/pkg/httpclient/pool.go
index 76b7aa91..6ef3d714 100644
--- a/backend/internal/pkg/httpclient/pool.go
+++ b/backend/internal/pkg/httpclient/pool.go
@@ -32,6 +32,7 @@ const (
defaultMaxIdleConns = 100 // 最大空闲连接数
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时)
+ validatedHostTTL = 30 * time.Second // DNS Rebinding 校验缓存 TTL
)
// Options 定义共享 HTTP 客户端的构建参数
@@ -53,6 +54,9 @@ type Options struct {
// sharedClients 存储按配置参数缓存的 http.Client 实例
var sharedClients sync.Map
+// 允许测试替换校验函数,生产默认指向真实实现。
+var validateResolvedIP = urlvalidator.ValidateResolvedIP
+
// GetClient 返回共享的 HTTP 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport
// 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险
@@ -84,7 +88,7 @@ func buildClient(opts Options) (*http.Client, error) {
var rt http.RoundTripper = transport
if opts.ValidateResolvedIP && !opts.AllowPrivateHosts {
- rt = &validatedTransport{base: transport}
+ rt = newValidatedTransport(transport)
}
return &http.Client{
Transport: rt,
@@ -149,17 +153,56 @@ func buildClientKey(opts Options) string {
}
type validatedTransport struct {
- base http.RoundTripper
+ base http.RoundTripper
+ validatedHosts sync.Map // map[string]time.Time, value 为过期时间
+ now func() time.Time
+}
+
+func newValidatedTransport(base http.RoundTripper) *validatedTransport {
+ return &validatedTransport{
+ base: base,
+ now: time.Now,
+ }
+}
+
+func (t *validatedTransport) isValidatedHost(host string, now time.Time) bool {
+ if t == nil {
+ return false
+ }
+ raw, ok := t.validatedHosts.Load(host)
+ if !ok {
+ return false
+ }
+ expireAt, ok := raw.(time.Time)
+ if !ok {
+ t.validatedHosts.Delete(host)
+ return false
+ }
+ if now.Before(expireAt) {
+ return true
+ }
+ t.validatedHosts.Delete(host)
+ return false
}
func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req != nil && req.URL != nil {
- host := strings.TrimSpace(req.URL.Hostname())
+ host := strings.ToLower(strings.TrimSpace(req.URL.Hostname()))
if host != "" {
- if err := urlvalidator.ValidateResolvedIP(host); err != nil {
- return nil, err
+ now := time.Now()
+ if t != nil && t.now != nil {
+ now = t.now()
+ }
+ if !t.isValidatedHost(host, now) {
+ if err := validateResolvedIP(host); err != nil {
+ return nil, err
+ }
+ t.validatedHosts.Store(host, now.Add(validatedHostTTL))
}
}
}
+ if t == nil || t.base == nil {
+ return nil, fmt.Errorf("validated transport base is nil")
+ }
return t.base.RoundTrip(req)
}
diff --git a/backend/internal/pkg/httpclient/pool_test.go b/backend/internal/pkg/httpclient/pool_test.go
new file mode 100644
index 00000000..f945758a
--- /dev/null
+++ b/backend/internal/pkg/httpclient/pool_test.go
@@ -0,0 +1,115 @@
+package httpclient
+
+import (
+ "errors"
+ "io"
+ "net/http"
+ "strings"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+type roundTripFunc func(*http.Request) (*http.Response, error)
+
+func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
+ return f(req)
+}
+
+func TestValidatedTransport_CacheHostValidation(t *testing.T) {
+ originalValidate := validateResolvedIP
+ defer func() { validateResolvedIP = originalValidate }()
+
+ var validateCalls int32
+ validateResolvedIP = func(host string) error {
+ atomic.AddInt32(&validateCalls, 1)
+ require.Equal(t, "api.openai.com", host)
+ return nil
+ }
+
+ var baseCalls int32
+ base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
+ atomic.AddInt32(&baseCalls, 1)
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{}`)),
+ Header: make(http.Header),
+ }, nil
+ })
+
+ now := time.Unix(1730000000, 0)
+ transport := newValidatedTransport(base)
+ transport.now = func() time.Time { return now }
+
+ req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
+ require.NoError(t, err)
+
+ _, err = transport.RoundTrip(req)
+ require.NoError(t, err)
+ _, err = transport.RoundTrip(req)
+ require.NoError(t, err)
+
+ require.Equal(t, int32(1), atomic.LoadInt32(&validateCalls))
+ require.Equal(t, int32(2), atomic.LoadInt32(&baseCalls))
+}
+
+func TestValidatedTransport_ExpiredCacheTriggersRevalidation(t *testing.T) {
+ originalValidate := validateResolvedIP
+ defer func() { validateResolvedIP = originalValidate }()
+
+ var validateCalls int32
+ validateResolvedIP = func(_ string) error {
+ atomic.AddInt32(&validateCalls, 1)
+ return nil
+ }
+
+ base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{}`)),
+ Header: make(http.Header),
+ }, nil
+ })
+
+ now := time.Unix(1730001000, 0)
+ transport := newValidatedTransport(base)
+ transport.now = func() time.Time { return now }
+
+ req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
+ require.NoError(t, err)
+
+ _, err = transport.RoundTrip(req)
+ require.NoError(t, err)
+
+ now = now.Add(validatedHostTTL + time.Second)
+ _, err = transport.RoundTrip(req)
+ require.NoError(t, err)
+
+ require.Equal(t, int32(2), atomic.LoadInt32(&validateCalls))
+}
+
+func TestValidatedTransport_ValidationErrorStopsRoundTrip(t *testing.T) {
+ originalValidate := validateResolvedIP
+ defer func() { validateResolvedIP = originalValidate }()
+
+ expectedErr := errors.New("dns rebinding rejected")
+ validateResolvedIP = func(_ string) error {
+ return expectedErr
+ }
+
+ var baseCalls int32
+ base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
+ atomic.AddInt32(&baseCalls, 1)
+ return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(`{}`))}, nil
+ })
+
+ transport := newValidatedTransport(base)
+ req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
+ require.NoError(t, err)
+
+ _, err = transport.RoundTrip(req)
+ require.ErrorIs(t, err, expectedErr)
+ require.Equal(t, int32(0), atomic.LoadInt32(&baseCalls))
+}
diff --git a/backend/internal/pkg/httputil/body.go b/backend/internal/pkg/httputil/body.go
new file mode 100644
index 00000000..69e99dc5
--- /dev/null
+++ b/backend/internal/pkg/httputil/body.go
@@ -0,0 +1,37 @@
+package httputil
+
+import (
+ "bytes"
+ "io"
+ "net/http"
+)
+
+const (
+ requestBodyReadInitCap = 512
+ requestBodyReadMaxInitCap = 1 << 20
+)
+
+// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length.
+func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
+ if req == nil || req.Body == nil {
+ return nil, nil
+ }
+
+ capHint := requestBodyReadInitCap
+ if req.ContentLength > 0 {
+ switch {
+ case req.ContentLength < int64(requestBodyReadInitCap):
+ capHint = requestBodyReadInitCap
+ case req.ContentLength > int64(requestBodyReadMaxInitCap):
+ capHint = requestBodyReadMaxInitCap
+ default:
+ capHint = int(req.ContentLength)
+ }
+ }
+
+ buf := bytes.NewBuffer(make([]byte, 0, capHint))
+ if _, err := io.Copy(buf, req.Body); err != nil {
+ return nil, err
+ }
+ return buf.Bytes(), nil
+}
diff --git a/backend/internal/pkg/ip/ip.go b/backend/internal/pkg/ip/ip.go
index 3f05ac41..f6f77c86 100644
--- a/backend/internal/pkg/ip/ip.go
+++ b/backend/internal/pkg/ip/ip.go
@@ -67,6 +67,14 @@ func normalizeIP(ip string) string {
// privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析
var privateNets []*net.IPNet
+// CompiledIPRules 表示预编译的 IP 匹配规则。
+// PatternCount 记录原始规则数量,用于保留“规则存在但全无效”时的行为语义。
+type CompiledIPRules struct {
+ CIDRs []*net.IPNet
+ IPs []net.IP
+ PatternCount int
+}
+
func init() {
for _, cidr := range []string{
"10.0.0.0/8",
@@ -84,6 +92,53 @@ func init() {
}
}
+// CompileIPRules 将 IP/CIDR 字符串规则预编译为可复用结构。
+// 非法规则会被忽略,但 PatternCount 会保留原始规则条数。
+func CompileIPRules(patterns []string) *CompiledIPRules {
+ compiled := &CompiledIPRules{
+ CIDRs: make([]*net.IPNet, 0, len(patterns)),
+ IPs: make([]net.IP, 0, len(patterns)),
+ PatternCount: len(patterns),
+ }
+ for _, pattern := range patterns {
+ normalized := strings.TrimSpace(pattern)
+ if normalized == "" {
+ continue
+ }
+ if strings.Contains(normalized, "/") {
+ _, cidr, err := net.ParseCIDR(normalized)
+ if err != nil || cidr == nil {
+ continue
+ }
+ compiled.CIDRs = append(compiled.CIDRs, cidr)
+ continue
+ }
+ parsedIP := net.ParseIP(normalized)
+ if parsedIP == nil {
+ continue
+ }
+ compiled.IPs = append(compiled.IPs, parsedIP)
+ }
+ return compiled
+}
+
+func matchesCompiledRules(parsedIP net.IP, rules *CompiledIPRules) bool {
+ if parsedIP == nil || rules == nil {
+ return false
+ }
+ for _, cidr := range rules.CIDRs {
+ if cidr.Contains(parsedIP) {
+ return true
+ }
+ }
+ for _, ruleIP := range rules.IPs {
+ if parsedIP.Equal(ruleIP) {
+ return true
+ }
+ }
+ return false
+}
+
// isPrivateIP 检查 IP 是否为私有地址。
func isPrivateIP(ipStr string) bool {
ip := net.ParseIP(ipStr)
@@ -142,19 +197,32 @@ func MatchesAnyPattern(clientIP string, patterns []string) bool {
// 2. 如果白名单不为空,IP 必须在白名单中
// 3. 如果白名单为空,允许访问(除非被黑名单拒绝)
func CheckIPRestriction(clientIP string, whitelist, blacklist []string) (bool, string) {
+ return CheckIPRestrictionWithCompiledRules(
+ clientIP,
+ CompileIPRules(whitelist),
+ CompileIPRules(blacklist),
+ )
+}
+
+// CheckIPRestrictionWithCompiledRules 使用预编译规则检查 IP 是否允许访问。
+func CheckIPRestrictionWithCompiledRules(clientIP string, whitelist, blacklist *CompiledIPRules) (bool, string) {
// 规范化 IP
clientIP = normalizeIP(clientIP)
if clientIP == "" {
return false, "access denied"
}
+ parsedIP := net.ParseIP(clientIP)
+ if parsedIP == nil {
+ return false, "access denied"
+ }
// 1. 检查黑名单
- if len(blacklist) > 0 && MatchesAnyPattern(clientIP, blacklist) {
+ if blacklist != nil && blacklist.PatternCount > 0 && matchesCompiledRules(parsedIP, blacklist) {
return false, "access denied"
}
// 2. 检查白名单(如果设置了白名单,IP 必须在其中)
- if len(whitelist) > 0 && !MatchesAnyPattern(clientIP, whitelist) {
+ if whitelist != nil && whitelist.PatternCount > 0 && !matchesCompiledRules(parsedIP, whitelist) {
return false, "access denied"
}
diff --git a/backend/internal/pkg/ip/ip_test.go b/backend/internal/pkg/ip/ip_test.go
index 3839403c..403b2d59 100644
--- a/backend/internal/pkg/ip/ip_test.go
+++ b/backend/internal/pkg/ip/ip_test.go
@@ -73,3 +73,24 @@ func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) {
require.Equal(t, 200, w.Code)
require.Equal(t, "9.9.9.9", w.Body.String())
}
+
+func TestCheckIPRestrictionWithCompiledRules(t *testing.T) {
+ whitelist := CompileIPRules([]string{"10.0.0.0/8", "192.168.1.2"})
+ blacklist := CompileIPRules([]string{"10.1.1.1"})
+
+ allowed, reason := CheckIPRestrictionWithCompiledRules("10.2.3.4", whitelist, blacklist)
+ require.True(t, allowed)
+ require.Equal(t, "", reason)
+
+ allowed, reason = CheckIPRestrictionWithCompiledRules("10.1.1.1", whitelist, blacklist)
+ require.False(t, allowed)
+ require.Equal(t, "access denied", reason)
+}
+
+func TestCheckIPRestrictionWithCompiledRules_InvalidWhitelistStillDenies(t *testing.T) {
+ // 与旧实现保持一致:白名单有配置但全无效时,最终应拒绝访问。
+ invalidWhitelist := CompileIPRules([]string{"not-a-valid-pattern"})
+ allowed, reason := CheckIPRestrictionWithCompiledRules("8.8.8.8", invalidWhitelist, nil)
+ require.False(t, allowed)
+ require.Equal(t, "access denied", reason)
+}
diff --git a/backend/internal/pkg/logger/logger.go b/backend/internal/pkg/logger/logger.go
index 80d92517..3fca706e 100644
--- a/backend/internal/pkg/logger/logger.go
+++ b/backend/internal/pkg/logger/logger.go
@@ -10,6 +10,7 @@ import (
"path/filepath"
"strings"
"sync"
+ "sync/atomic"
"time"
"go.uber.org/zap"
@@ -42,15 +43,19 @@ type LogEvent struct {
var (
mu sync.RWMutex
- global *zap.Logger
- sugar *zap.SugaredLogger
+ global atomic.Pointer[zap.Logger]
+ sugar atomic.Pointer[zap.SugaredLogger]
atomicLevel zap.AtomicLevel
initOptions InitOptions
- currentSink Sink
+ currentSink atomic.Value // sinkState
stdLogUndo func()
bootstrapOnce sync.Once
)
+type sinkState struct {
+ sink Sink
+}
+
func InitBootstrap() {
bootstrapOnce.Do(func() {
if err := Init(bootstrapOptions()); err != nil {
@@ -72,9 +77,9 @@ func initLocked(options InitOptions) error {
return err
}
- prev := global
- global = zl
- sugar = zl.Sugar()
+ prev := global.Load()
+ global.Store(zl)
+ sugar.Store(zl.Sugar())
atomicLevel = al
initOptions = normalized
@@ -115,24 +120,32 @@ func SetLevel(level string) error {
func CurrentLevel() string {
mu.RLock()
defer mu.RUnlock()
- if global == nil {
+ if global.Load() == nil {
return "info"
}
return atomicLevel.Level().String()
}
func SetSink(sink Sink) {
- mu.Lock()
- defer mu.Unlock()
- currentSink = sink
+ currentSink.Store(sinkState{sink: sink})
+}
+
+func loadSink() Sink {
+ v := currentSink.Load()
+ if v == nil {
+ return nil
+ }
+ state, ok := v.(sinkState)
+ if !ok {
+ return nil
+ }
+ return state.sink
}
// WriteSinkEvent 直接写入日志 sink,不经过全局日志级别门控。
// 用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。
func WriteSinkEvent(level, component, message string, fields map[string]any) {
- mu.RLock()
- sink := currentSink
- mu.RUnlock()
+ sink := loadSink()
if sink == nil {
return
}
@@ -168,19 +181,15 @@ func WriteSinkEvent(level, component, message string, fields map[string]any) {
}
func L() *zap.Logger {
- mu.RLock()
- defer mu.RUnlock()
- if global != nil {
- return global
+ if l := global.Load(); l != nil {
+ return l
}
return zap.NewNop()
}
func S() *zap.SugaredLogger {
- mu.RLock()
- defer mu.RUnlock()
- if sugar != nil {
- return sugar
+ if s := sugar.Load(); s != nil {
+ return s
}
return zap.NewNop().Sugar()
}
@@ -190,9 +199,7 @@ func With(fields ...zap.Field) *zap.Logger {
}
func Sync() {
- mu.RLock()
- l := global
- mu.RUnlock()
+ l := global.Load()
if l != nil {
_ = l.Sync()
}
@@ -210,7 +217,11 @@ func bridgeStdLogLocked() {
log.SetFlags(0)
log.SetPrefix("")
- log.SetOutput(newStdLogBridge(global.Named("stdlog")))
+ base := global.Load()
+ if base == nil {
+ base = zap.NewNop()
+ }
+ log.SetOutput(newStdLogBridge(base.Named("stdlog")))
stdLogUndo = func() {
log.SetOutput(prevWriter)
@@ -220,7 +231,11 @@ func bridgeStdLogLocked() {
}
func bridgeSlogLocked() {
- slog.SetDefault(slog.New(newSlogZapHandler(global.Named("slog"))))
+ base := global.Load()
+ if base == nil {
+ base = zap.NewNop()
+ }
+ slog.SetDefault(slog.New(newSlogZapHandler(base.Named("slog"))))
}
func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) {
@@ -363,9 +378,7 @@ func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore
func (s *sinkCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
// Only handle sink forwarding — the inner cores write via their own
// Write methods (added to CheckedEntry by s.core.Check above).
- mu.RLock()
- sink := currentSink
- mu.RUnlock()
+ sink := loadSink()
if sink == nil {
return nil
}
@@ -454,7 +467,7 @@ func inferStdLogLevel(msg string) Level {
if strings.Contains(lower, " failed") || strings.Contains(lower, "error") || strings.Contains(lower, "panic") || strings.Contains(lower, "fatal") {
return LevelError
}
- if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " retry") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") {
+ if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") {
return LevelWarn
}
return LevelInfo
@@ -467,9 +480,7 @@ func LegacyPrintf(component, format string, args ...any) {
return
}
- mu.RLock()
- initialized := global != nil
- mu.RUnlock()
+ initialized := global.Load() != nil
if !initialized {
// 在日志系统未初始化前,回退到标准库 log,避免测试/工具链丢日志。
log.Print(msg)
diff --git a/backend/internal/pkg/logger/slog_handler.go b/backend/internal/pkg/logger/slog_handler.go
index 562b8341..602ca1e0 100644
--- a/backend/internal/pkg/logger/slog_handler.go
+++ b/backend/internal/pkg/logger/slog_handler.go
@@ -48,16 +48,15 @@ func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error {
return true
})
- entry := h.logger.With(fields...)
switch {
case record.Level >= slog.LevelError:
- entry.Error(record.Message)
+ h.logger.Error(record.Message, fields...)
case record.Level >= slog.LevelWarn:
- entry.Warn(record.Message)
+ h.logger.Warn(record.Message, fields...)
case record.Level <= slog.LevelDebug:
- entry.Debug(record.Message)
+ h.logger.Debug(record.Message, fields...)
default:
- entry.Info(record.Message)
+ h.logger.Info(record.Message, fields...)
}
return nil
}
diff --git a/backend/internal/pkg/logger/stdlog_bridge_test.go b/backend/internal/pkg/logger/stdlog_bridge_test.go
index a3f76fd7..4482a2ec 100644
--- a/backend/internal/pkg/logger/stdlog_bridge_test.go
+++ b/backend/internal/pkg/logger/stdlog_bridge_test.go
@@ -16,6 +16,7 @@ func TestInferStdLogLevel(t *testing.T) {
{msg: "Warning: queue full", want: LevelWarn},
{msg: "Forward request failed: timeout", want: LevelError},
{msg: "[ERROR] upstream unavailable", want: LevelError},
+ {msg: "[OpenAI WS Mode] reconnect_retry account_id=22 retry=1 max_retries=5", want: LevelInfo},
{msg: "service started", want: LevelInfo},
{msg: "debug: cache miss", want: LevelDebug},
}
diff --git a/backend/internal/pkg/openai/oauth.go b/backend/internal/pkg/openai/oauth.go
index e3b931be..8bdcbe16 100644
--- a/backend/internal/pkg/openai/oauth.go
+++ b/backend/internal/pkg/openai/oauth.go
@@ -36,10 +36,18 @@ const (
SessionTTL = 30 * time.Minute
)
+const (
+ // OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client.
+ OAuthPlatformOpenAI = "openai"
+ // OAuthPlatformSora uses Sora OAuth client.
+ OAuthPlatformSora = "sora"
+)
+
// OAuthSession stores OAuth flow state for OpenAI
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
+ ClientID string `json:"client_id,omitempty"`
ProxyURL string `json:"proxy_url,omitempty"`
RedirectURI string `json:"redirect_uri"`
CreatedAt time.Time `json:"created_at"`
@@ -174,13 +182,20 @@ func base64URLEncode(data []byte) string {
// BuildAuthorizationURL builds the OpenAI OAuth authorization URL
func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
+ return BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, OAuthPlatformOpenAI)
+}
+
+// BuildAuthorizationURLForPlatform builds authorization URL by platform.
+func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platform string) string {
if redirectURI == "" {
redirectURI = DefaultRedirectURI
}
+ clientID, codexFlow := OAuthClientConfigByPlatform(platform)
+
params := url.Values{}
params.Set("response_type", "code")
- params.Set("client_id", ClientID)
+ params.Set("client_id", clientID)
params.Set("redirect_uri", redirectURI)
params.Set("scope", DefaultScopes)
params.Set("state", state)
@@ -188,11 +203,25 @@ func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
params.Set("code_challenge_method", "S256")
// OpenAI specific parameters
params.Set("id_token_add_organizations", "true")
- params.Set("codex_cli_simplified_flow", "true")
+ if codexFlow {
+ params.Set("codex_cli_simplified_flow", "true")
+ }
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
+// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled.
+// Sora 授权流程复用 Codex CLI 的 client_id(支持 localhost redirect_uri),
+// 但不启用 codex_cli_simplified_flow;拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。
+func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) {
+ switch strings.ToLower(strings.TrimSpace(platform)) {
+ case OAuthPlatformSora:
+ return ClientID, false
+ default:
+ return ClientID, true
+ }
+}
+
// TokenRequest represents the token exchange request body
type TokenRequest struct {
GrantType string `json:"grant_type"`
@@ -296,9 +325,11 @@ func (r *RefreshTokenRequest) ToFormData() string {
return params.Encode()
}
-// ParseIDToken parses the ID Token JWT and extracts claims
-// Note: This does NOT verify the signature - it only decodes the payload
-// For production, you should verify the token signature using OpenAI's public keys
+// ParseIDToken parses the ID Token JWT and extracts claims.
+// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。
+// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
+//
+// https://auth.openai.com/.well-known/jwks.json
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
parts := strings.Split(idToken, ".")
if len(parts) != 3 {
@@ -329,6 +360,13 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
}
+ // 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌)
+ const clockSkewTolerance = 120 // 秒
+ now := time.Now().Unix()
+ if claims.Exp > 0 && now > claims.Exp+clockSkewTolerance {
+ return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance)
+ }
+
return &claims, nil
}
diff --git a/backend/internal/pkg/openai/oauth_test.go b/backend/internal/pkg/openai/oauth_test.go
index f1d616a6..2970addf 100644
--- a/backend/internal/pkg/openai/oauth_test.go
+++ b/backend/internal/pkg/openai/oauth_test.go
@@ -1,6 +1,7 @@
package openai
import (
+ "net/url"
"sync"
"testing"
"time"
@@ -41,3 +42,41 @@ func TestSessionStore_Stop_Concurrent(t *testing.T) {
t.Fatal("stopCh 未关闭")
}
}
+
+func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) {
+ authURL := BuildAuthorizationURLForPlatform("state-1", "challenge-1", DefaultRedirectURI, OAuthPlatformOpenAI)
+ parsed, err := url.Parse(authURL)
+ if err != nil {
+ t.Fatalf("Parse URL failed: %v", err)
+ }
+ q := parsed.Query()
+ if got := q.Get("client_id"); got != ClientID {
+ t.Fatalf("client_id mismatch: got=%q want=%q", got, ClientID)
+ }
+ if got := q.Get("codex_cli_simplified_flow"); got != "true" {
+ t.Fatalf("codex flow mismatch: got=%q want=true", got)
+ }
+ if got := q.Get("id_token_add_organizations"); got != "true" {
+ t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
+ }
+}
+
+// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id,
+// 但不启用 codex_cli_simplified_flow。
+func TestBuildAuthorizationURLForPlatform_Sora(t *testing.T) {
+ authURL := BuildAuthorizationURLForPlatform("state-2", "challenge-2", DefaultRedirectURI, OAuthPlatformSora)
+ parsed, err := url.Parse(authURL)
+ if err != nil {
+ t.Fatalf("Parse URL failed: %v", err)
+ }
+ q := parsed.Query()
+ if got := q.Get("client_id"); got != ClientID {
+ t.Fatalf("client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)", got, ClientID)
+ }
+ if got := q.Get("codex_cli_simplified_flow"); got != "" {
+ t.Fatalf("codex flow should be empty for sora, got=%q", got)
+ }
+ if got := q.Get("id_token_add_organizations"); got != "true" {
+ t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
+ }
+}
diff --git a/backend/internal/pkg/response/response_test.go b/backend/internal/pkg/response/response_test.go
index 3c12f5f4..0debce5f 100644
--- a/backend/internal/pkg/response/response_test.go
+++ b/backend/internal/pkg/response/response_test.go
@@ -29,10 +29,10 @@ func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, P
t.Helper()
// 先用 raw json 解析,因为 Data 是 any 类型
var raw struct {
- Code int `json:"code"`
- Message string `json:"message"`
- Reason string `json:"reason,omitempty"`
- Data json.RawMessage `json:"data,omitempty"`
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Reason string `json:"reason,omitempty"`
+ Data json.RawMessage `json:"data,omitempty"`
}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw))
diff --git a/backend/internal/pkg/tlsfingerprint/dialer.go b/backend/internal/pkg/tlsfingerprint/dialer.go
index 992f8b0a..4f25a34a 100644
--- a/backend/internal/pkg/tlsfingerprint/dialer.go
+++ b/backend/internal/pkg/tlsfingerprint/dialer.go
@@ -268,8 +268,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
"cipher_suites", len(spec.CipherSuites),
"extensions", len(spec.Extensions),
"compression_methods", spec.CompressionMethods,
- "tls_vers_max", fmt.Sprintf("0x%04x", spec.TLSVersMax),
- "tls_vers_min", fmt.Sprintf("0x%04x", spec.TLSVersMin))
+ "tls_vers_max", spec.TLSVersMax,
+ "tls_vers_min", spec.TLSVersMin)
if d.profile != nil {
slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
@@ -294,8 +294,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
state := tlsConn.ConnectionState()
slog.Debug("tls_fingerprint_socks5_handshake_success",
- "version", fmt.Sprintf("0x%04x", state.Version),
- "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
+ "version", state.Version,
+ "cipher_suite", state.CipherSuite,
"alpn", state.NegotiatedProtocol)
return tlsConn, nil
@@ -404,8 +404,8 @@ func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr stri
state := tlsConn.ConnectionState()
slog.Debug("tls_fingerprint_http_proxy_handshake_success",
- "version", fmt.Sprintf("0x%04x", state.Version),
- "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
+ "version", state.Version,
+ "cipher_suite", state.CipherSuite,
"alpn", state.NegotiatedProtocol)
return tlsConn, nil
@@ -470,8 +470,8 @@ func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.
// Log successful handshake details
state := tlsConn.ConnectionState()
slog.Debug("tls_fingerprint_handshake_success",
- "version", fmt.Sprintf("0x%04x", state.Version),
- "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
+ "version", state.Version,
+ "cipher_suite", state.CipherSuite,
"alpn", state.NegotiatedProtocol)
return tlsConn, nil
diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go
index 2f6c7fe0..314a6d3c 100644
--- a/backend/internal/pkg/usagestats/usage_log_types.go
+++ b/backend/internal/pkg/usagestats/usage_log_types.go
@@ -78,6 +78,16 @@ type ModelStat struct {
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
+// GroupStat represents usage statistics for a single group
+type GroupStat struct {
+ GroupID int64 `json:"group_id"`
+ GroupName string `json:"group_name"`
+ Requests int64 `json:"requests"`
+ TotalTokens int64 `json:"total_tokens"`
+ Cost float64 `json:"cost"` // 标准计费
+ ActualCost float64 `json:"actual_cost"` // 实际扣除
+}
+
// UserUsageTrendPoint represents user usage trend data point
type UserUsageTrendPoint struct {
Date string `json:"date"`
@@ -139,6 +149,7 @@ type UsageLogFilters struct {
AccountID int64
GroupID int64
Model string
+ RequestType *int16
Stream *bool
BillingType *int8
StartTime *time.Time
diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go
index 3f77a57e..4aa74928 100644
--- a/backend/internal/repository/account_repo.go
+++ b/backend/internal/repository/account_repo.go
@@ -50,11 +50,6 @@ type accountRepository struct {
schedulerCache service.SchedulerCache
}
-type tempUnschedSnapshot struct {
- until *time.Time
- reason string
-}
-
// NewAccountRepository 创建账户仓储实例。
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
@@ -189,11 +184,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
accountIDs = append(accountIDs, acc.ID)
}
- tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
- if err != nil {
- return nil, err
- }
-
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil {
return nil, err
@@ -220,10 +210,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
if ags, ok := accountGroupsByAccount[entAcc.ID]; ok {
out.AccountGroups = ags
}
- if snap, ok := tempUnschedMap[entAcc.ID]; ok {
- out.TempUnschedulableUntil = snap.until
- out.TempUnschedulableReason = snap.reason
- }
outByID[entAcc.ID] = out
}
@@ -611,6 +597,43 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
}
}
+func (r *accountRepository) syncSchedulerAccountSnapshots(ctx context.Context, accountIDs []int64) {
+ if r == nil || r.schedulerCache == nil || len(accountIDs) == 0 {
+ return
+ }
+
+ uniqueIDs := make([]int64, 0, len(accountIDs))
+ seen := make(map[int64]struct{}, len(accountIDs))
+ for _, id := range accountIDs {
+ if id <= 0 {
+ continue
+ }
+ if _, exists := seen[id]; exists {
+ continue
+ }
+ seen[id] = struct{}{}
+ uniqueIDs = append(uniqueIDs, id)
+ }
+ if len(uniqueIDs) == 0 {
+ return
+ }
+
+ accounts, err := r.GetByIDs(ctx, uniqueIDs)
+ if err != nil {
+ logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot read failed: count=%d err=%v", len(uniqueIDs), err)
+ return
+ }
+
+ for _, account := range accounts {
+ if account == nil {
+ continue
+ }
+ if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
+ logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot write failed: id=%d err=%v", account.ID, err)
+ }
+ }
+}
+
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
@@ -1197,9 +1220,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
shouldSync = true
}
if shouldSync {
- for _, id := range ids {
- r.syncSchedulerAccountSnapshot(ctx, id)
- }
+ r.syncSchedulerAccountSnapshots(ctx, ids)
}
}
return rows, nil
@@ -1291,10 +1312,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
if err != nil {
return nil, err
}
- tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
- if err != nil {
- return nil, err
- }
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil {
return nil, err
@@ -1320,10 +1337,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
if ags, ok := accountGroupsByAccount[acc.ID]; ok {
out.AccountGroups = ags
}
- if snap, ok := tempUnschedMap[acc.ID]; ok {
- out.TempUnschedulableUntil = snap.until
- out.TempUnschedulableReason = snap.reason
- }
outAccounts = append(outAccounts, *out)
}
@@ -1348,48 +1361,6 @@ func notExpiredPredicate(now time.Time) dbpredicate.Account {
)
}
-func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) {
- out := make(map[int64]tempUnschedSnapshot)
- if len(accountIDs) == 0 {
- return out, nil
- }
-
- rows, err := r.sql.QueryContext(ctx, `
- SELECT id, temp_unschedulable_until, temp_unschedulable_reason
- FROM accounts
- WHERE id = ANY($1)
- `, pq.Array(accountIDs))
- if err != nil {
- return nil, err
- }
- defer func() { _ = rows.Close() }()
-
- for rows.Next() {
- var id int64
- var until sql.NullTime
- var reason sql.NullString
- if err := rows.Scan(&id, &until, &reason); err != nil {
- return nil, err
- }
- var untilPtr *time.Time
- if until.Valid {
- tmp := until.Time
- untilPtr = &tmp
- }
- if reason.Valid {
- out[id] = tempUnschedSnapshot{until: untilPtr, reason: reason.String}
- } else {
- out[id] = tempUnschedSnapshot{until: untilPtr, reason: ""}
- }
- }
-
- if err := rows.Err(); err != nil {
- return nil, err
- }
-
- return out, nil
-}
-
func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) {
proxyMap := make(map[int64]*service.Proxy)
if len(proxyIDs) == 0 {
@@ -1500,31 +1471,33 @@ func accountEntityToService(m *dbent.Account) *service.Account {
rateMultiplier := m.RateMultiplier
return &service.Account{
- ID: m.ID,
- Name: m.Name,
- Notes: m.Notes,
- Platform: m.Platform,
- Type: m.Type,
- Credentials: copyJSONMap(m.Credentials),
- Extra: copyJSONMap(m.Extra),
- ProxyID: m.ProxyID,
- Concurrency: m.Concurrency,
- Priority: m.Priority,
- RateMultiplier: &rateMultiplier,
- Status: m.Status,
- ErrorMessage: derefString(m.ErrorMessage),
- LastUsedAt: m.LastUsedAt,
- ExpiresAt: m.ExpiresAt,
- AutoPauseOnExpired: m.AutoPauseOnExpired,
- CreatedAt: m.CreatedAt,
- UpdatedAt: m.UpdatedAt,
- Schedulable: m.Schedulable,
- RateLimitedAt: m.RateLimitedAt,
- RateLimitResetAt: m.RateLimitResetAt,
- OverloadUntil: m.OverloadUntil,
- SessionWindowStart: m.SessionWindowStart,
- SessionWindowEnd: m.SessionWindowEnd,
- SessionWindowStatus: derefString(m.SessionWindowStatus),
+ ID: m.ID,
+ Name: m.Name,
+ Notes: m.Notes,
+ Platform: m.Platform,
+ Type: m.Type,
+ Credentials: copyJSONMap(m.Credentials),
+ Extra: copyJSONMap(m.Extra),
+ ProxyID: m.ProxyID,
+ Concurrency: m.Concurrency,
+ Priority: m.Priority,
+ RateMultiplier: &rateMultiplier,
+ Status: m.Status,
+ ErrorMessage: derefString(m.ErrorMessage),
+ LastUsedAt: m.LastUsedAt,
+ ExpiresAt: m.ExpiresAt,
+ AutoPauseOnExpired: m.AutoPauseOnExpired,
+ CreatedAt: m.CreatedAt,
+ UpdatedAt: m.UpdatedAt,
+ Schedulable: m.Schedulable,
+ RateLimitedAt: m.RateLimitedAt,
+ RateLimitResetAt: m.RateLimitResetAt,
+ OverloadUntil: m.OverloadUntil,
+ TempUnschedulableUntil: m.TempUnschedulableUntil,
+ TempUnschedulableReason: derefString(m.TempUnschedulableReason),
+ SessionWindowStart: m.SessionWindowStart,
+ SessionWindowEnd: m.SessionWindowEnd,
+ SessionWindowStatus: derefString(m.SessionWindowStatus),
}
}
diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go
index 4f9d0152..fd48a5d4 100644
--- a/backend/internal/repository/account_repo_integration_test.go
+++ b/backend/internal/repository/account_repo_integration_test.go
@@ -500,6 +500,38 @@ func (s *AccountRepoSuite) TestClearRateLimit() {
s.Require().Nil(got.OverloadUntil)
}
+func (s *AccountRepoSuite) TestTempUnschedulableFieldsLoadedByGetByIDAndGetByIDs() {
+ acc1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-1"})
+ acc2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-2"})
+
+ until := time.Now().Add(15 * time.Minute).UTC().Truncate(time.Second)
+ reason := `{"rule":"429","matched_keyword":"too many requests"}`
+ s.Require().NoError(s.repo.SetTempUnschedulable(s.ctx, acc1.ID, until, reason))
+
+ gotByID, err := s.repo.GetByID(s.ctx, acc1.ID)
+ s.Require().NoError(err)
+ s.Require().NotNil(gotByID.TempUnschedulableUntil)
+ s.Require().WithinDuration(until, *gotByID.TempUnschedulableUntil, time.Second)
+ s.Require().Equal(reason, gotByID.TempUnschedulableReason)
+
+ gotByIDs, err := s.repo.GetByIDs(s.ctx, []int64{acc2.ID, acc1.ID})
+ s.Require().NoError(err)
+ s.Require().Len(gotByIDs, 2)
+ s.Require().Equal(acc2.ID, gotByIDs[0].ID)
+ s.Require().Nil(gotByIDs[0].TempUnschedulableUntil)
+ s.Require().Equal("", gotByIDs[0].TempUnschedulableReason)
+ s.Require().Equal(acc1.ID, gotByIDs[1].ID)
+ s.Require().NotNil(gotByIDs[1].TempUnschedulableUntil)
+ s.Require().WithinDuration(until, *gotByIDs[1].TempUnschedulableUntil, time.Second)
+ s.Require().Equal(reason, gotByIDs[1].TempUnschedulableReason)
+
+ s.Require().NoError(s.repo.ClearTempUnschedulable(s.ctx, acc1.ID))
+ cleared, err := s.repo.GetByID(s.ctx, acc1.ID)
+ s.Require().NoError(err)
+ s.Require().Nil(cleared.TempUnschedulableUntil)
+ s.Require().Equal("", cleared.TempUnschedulableReason)
+}
+
// --- UpdateLastUsed ---
func (s *AccountRepoSuite) TestUpdateLastUsed() {
diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go
index cdccd4fc..b9ce60a5 100644
--- a/backend/internal/repository/api_key_repo.go
+++ b/backend/internal/repository/api_key_repo.go
@@ -171,8 +171,9 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
// 则会更新已删除的记录。
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。
// 同时显式设置 updated_at,避免二次查询带来的并发可见性问题。
+ client := clientFromContext(ctx, r.client)
now := time.Now()
- builder := r.client.APIKey.Update().
+ builder := client.APIKey.Update().
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
SetName(key.Name).
SetStatus(key.Status).
@@ -445,20 +446,22 @@ func userEntityToService(u *dbent.User) *service.User {
return nil
}
return &service.User{
- ID: u.ID,
- Email: u.Email,
- Username: u.Username,
- Notes: u.Notes,
- PasswordHash: u.PasswordHash,
- Role: u.Role,
- Balance: u.Balance,
- Concurrency: u.Concurrency,
- Status: u.Status,
- TotpSecretEncrypted: u.TotpSecretEncrypted,
- TotpEnabled: u.TotpEnabled,
- TotpEnabledAt: u.TotpEnabledAt,
- CreatedAt: u.CreatedAt,
- UpdatedAt: u.UpdatedAt,
+ ID: u.ID,
+ Email: u.Email,
+ Username: u.Username,
+ Notes: u.Notes,
+ PasswordHash: u.PasswordHash,
+ Role: u.Role,
+ Balance: u.Balance,
+ Concurrency: u.Concurrency,
+ Status: u.Status,
+ SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
+ SoraStorageUsedBytes: u.SoraStorageUsedBytes,
+ TotpSecretEncrypted: u.TotpSecretEncrypted,
+ TotpEnabled: u.TotpEnabled,
+ TotpEnabledAt: u.TotpEnabledAt,
+ CreatedAt: u.CreatedAt,
+ UpdatedAt: u.UpdatedAt,
}
}
@@ -486,6 +489,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
SoraImagePrice540: g.SoraImagePrice540,
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
+ SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go
index e047bff0..a2552715 100644
--- a/backend/internal/repository/concurrency_cache.go
+++ b/backend/internal/repository/concurrency_cache.go
@@ -227,6 +227,43 @@ func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID
return result, nil
}
+func (c *concurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
+ if len(accountIDs) == 0 {
+ return map[int64]int{}, nil
+ }
+
+ now, err := c.rdb.Time(ctx).Result()
+ if err != nil {
+ return nil, fmt.Errorf("redis TIME: %w", err)
+ }
+ cutoffTime := now.Unix() - int64(c.slotTTLSeconds)
+
+ pipe := c.rdb.Pipeline()
+ type accountCmd struct {
+ accountID int64
+ zcardCmd *redis.IntCmd
+ }
+ cmds := make([]accountCmd, 0, len(accountIDs))
+ for _, accountID := range accountIDs {
+ slotKey := accountSlotKeyPrefix + strconv.FormatInt(accountID, 10)
+ pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10))
+ cmds = append(cmds, accountCmd{
+ accountID: accountID,
+ zcardCmd: pipe.ZCard(ctx, slotKey),
+ })
+ }
+
+ if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
+ return nil, fmt.Errorf("pipeline exec: %w", err)
+ }
+
+ result := make(map[int64]int, len(accountIDs))
+ for _, cmd := range cmds {
+ result[cmd.accountID] = int(cmd.zcardCmd.Val())
+ }
+ return result, nil
+}
+
// User slot operations
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
diff --git a/backend/internal/repository/gateway_cache_integration_test.go b/backend/internal/repository/gateway_cache_integration_test.go
index 2fdaa3d1..0eebc33f 100644
--- a/backend/internal/repository/gateway_cache_integration_test.go
+++ b/backend/internal/repository/gateway_cache_integration_test.go
@@ -104,7 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
}
-
func TestGatewayCacheSuite(t *testing.T) {
suite.Run(t, new(GatewayCacheSuite))
}
diff --git a/backend/internal/repository/gemini_drive_client.go b/backend/internal/repository/gemini_drive_client.go
new file mode 100644
index 00000000..2e383595
--- /dev/null
+++ b/backend/internal/repository/gemini_drive_client.go
@@ -0,0 +1,9 @@
+package repository
+
+import "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
+
+// NewGeminiDriveClient creates a concrete DriveClient for Google Drive API operations.
+// Returned as geminicli.DriveClient interface for DI (Strategy A).
+func NewGeminiDriveClient() geminicli.DriveClient {
+ return geminicli.NewDriveClient()
+}
diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go
index fd239996..4edc8534 100644
--- a/backend/internal/repository/group_repo.go
+++ b/backend/internal/repository/group_repo.go
@@ -4,6 +4,8 @@ import (
"context"
"database/sql"
"errors"
+ "fmt"
+ "strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
@@ -56,7 +58,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
- SetMcpXMLInject(groupIn.MCPXMLInject)
+ SetMcpXMLInject(groupIn.MCPXMLInject).
+ SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
// 设置模型路由配置
if groupIn.ModelRouting != nil {
@@ -121,7 +124,40 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
- SetMcpXMLInject(groupIn.MCPXMLInject)
+ SetMcpXMLInject(groupIn.MCPXMLInject).
+ SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
+
+ // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
+ if groupIn.DailyLimitUSD != nil {
+ builder = builder.SetDailyLimitUsd(*groupIn.DailyLimitUSD)
+ } else {
+ builder = builder.ClearDailyLimitUsd()
+ }
+ if groupIn.WeeklyLimitUSD != nil {
+ builder = builder.SetWeeklyLimitUsd(*groupIn.WeeklyLimitUSD)
+ } else {
+ builder = builder.ClearWeeklyLimitUsd()
+ }
+ if groupIn.MonthlyLimitUSD != nil {
+ builder = builder.SetMonthlyLimitUsd(*groupIn.MonthlyLimitUSD)
+ } else {
+ builder = builder.ClearMonthlyLimitUsd()
+ }
+ if groupIn.ImagePrice1K != nil {
+ builder = builder.SetImagePrice1k(*groupIn.ImagePrice1K)
+ } else {
+ builder = builder.ClearImagePrice1k()
+ }
+ if groupIn.ImagePrice2K != nil {
+ builder = builder.SetImagePrice2k(*groupIn.ImagePrice2K)
+ } else {
+ builder = builder.ClearImagePrice2k()
+ }
+ if groupIn.ImagePrice4K != nil {
+ builder = builder.SetImagePrice4k(*groupIn.ImagePrice4K)
+ } else {
+ builder = builder.ClearImagePrice4k()
+ }
// 处理 FallbackGroupID:nil 时清除,否则设置
if groupIn.FallbackGroupID != nil {
@@ -281,6 +317,54 @@ func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool,
return r.client.Group.Query().Where(group.NameEQ(name)).Exist(ctx)
}
+// ExistsByIDs 批量检查分组是否存在(仅检查未软删除记录)。
+// 返回结构:map[groupID]exists。
+func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) {
+ result := make(map[int64]bool, len(ids))
+ if len(ids) == 0 {
+ return result, nil
+ }
+
+ uniqueIDs := make([]int64, 0, len(ids))
+ seen := make(map[int64]struct{}, len(ids))
+ for _, id := range ids {
+ if id <= 0 {
+ continue
+ }
+ if _, ok := seen[id]; ok {
+ continue
+ }
+ seen[id] = struct{}{}
+ uniqueIDs = append(uniqueIDs, id)
+ result[id] = false
+ }
+ if len(uniqueIDs) == 0 {
+ return result, nil
+ }
+
+ rows, err := r.sql.QueryContext(ctx, `
+ SELECT id
+ FROM groups
+ WHERE id = ANY($1) AND deleted_at IS NULL
+ `, pq.Array(uniqueIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ for rows.Next() {
+ var id int64
+ if err := rows.Scan(&id); err != nil {
+ return nil, err
+ }
+ result[id] = true
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return result, nil
+}
+
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
@@ -512,22 +596,72 @@ func (r *groupRepository) UpdateSortOrders(ctx context.Context, updates []servic
return nil
}
- // 使用事务批量更新
- tx, err := r.client.Tx(ctx)
+ // 去重后保留最后一次排序值,避免重复 ID 造成 CASE 分支冲突。
+ sortOrderByID := make(map[int64]int, len(updates))
+ groupIDs := make([]int64, 0, len(updates))
+ for _, u := range updates {
+ if u.ID <= 0 {
+ continue
+ }
+ if _, exists := sortOrderByID[u.ID]; !exists {
+ groupIDs = append(groupIDs, u.ID)
+ }
+ sortOrderByID[u.ID] = u.SortOrder
+ }
+ if len(groupIDs) == 0 {
+ return nil
+ }
+
+ // 与旧实现保持一致:任何不存在/已删除的分组都返回 not found,且不执行更新。
+ var existingCount int
+ if err := scanSingleRow(
+ ctx,
+ r.sql,
+ `SELECT COUNT(*) FROM groups WHERE deleted_at IS NULL AND id = ANY($1)`,
+ []any{pq.Array(groupIDs)},
+ &existingCount,
+ ); err != nil {
+ return err
+ }
+ if existingCount != len(groupIDs) {
+ return service.ErrGroupNotFound
+ }
+
+ args := make([]any, 0, len(groupIDs)*2+1)
+ caseClauses := make([]string, 0, len(groupIDs))
+ placeholder := 1
+ for _, id := range groupIDs {
+ caseClauses = append(caseClauses, fmt.Sprintf("WHEN $%d THEN $%d", placeholder, placeholder+1))
+ args = append(args, id, sortOrderByID[id])
+ placeholder += 2
+ }
+ args = append(args, pq.Array(groupIDs))
+
+ query := fmt.Sprintf(`
+ UPDATE groups
+ SET sort_order = CASE id
+ %s
+ ELSE sort_order
+ END
+ WHERE deleted_at IS NULL AND id = ANY($%d)
+ `, strings.Join(caseClauses, "\n\t\t\t"), placeholder)
+
+ result, err := r.sql.ExecContext(ctx, query, args...)
if err != nil {
return err
}
- defer func() { _ = tx.Rollback() }()
-
- for _, u := range updates {
- if _, err := tx.Group.UpdateOneID(u.ID).SetSortOrder(u.SortOrder).Save(ctx); err != nil {
- return translatePersistenceError(err, service.ErrGroupNotFound, nil)
- }
- }
-
- if err := tx.Commit(); err != nil {
+ affected, err := result.RowsAffected()
+ if err != nil {
return err
}
+ if affected != int64(len(groupIDs)) {
+ return service.ErrGroupNotFound
+ }
+ for _, id := range groupIDs {
+ if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
+ logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group sort update failed: group=%d err=%v", id, err)
+ }
+ }
return nil
}
diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go
index c31a9ec4..4a849a46 100644
--- a/backend/internal/repository/group_repo_integration_test.go
+++ b/backend/internal/repository/group_repo_integration_test.go
@@ -352,6 +352,81 @@ func (s *GroupRepoSuite) TestListWithFilters_Search() {
})
}
+func (s *GroupRepoSuite) TestUpdateSortOrders_BatchCaseWhen() {
+ g1 := &service.Group{
+ Name: "sort-g1",
+ Platform: service.PlatformAnthropic,
+ RateMultiplier: 1.0,
+ IsExclusive: false,
+ Status: service.StatusActive,
+ SubscriptionType: service.SubscriptionTypeStandard,
+ }
+ g2 := &service.Group{
+ Name: "sort-g2",
+ Platform: service.PlatformAnthropic,
+ RateMultiplier: 1.0,
+ IsExclusive: false,
+ Status: service.StatusActive,
+ SubscriptionType: service.SubscriptionTypeStandard,
+ }
+ g3 := &service.Group{
+ Name: "sort-g3",
+ Platform: service.PlatformAnthropic,
+ RateMultiplier: 1.0,
+ IsExclusive: false,
+ Status: service.StatusActive,
+ SubscriptionType: service.SubscriptionTypeStandard,
+ }
+ s.Require().NoError(s.repo.Create(s.ctx, g1))
+ s.Require().NoError(s.repo.Create(s.ctx, g2))
+ s.Require().NoError(s.repo.Create(s.ctx, g3))
+
+ err := s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{
+ {ID: g1.ID, SortOrder: 30},
+ {ID: g2.ID, SortOrder: 10},
+ {ID: g3.ID, SortOrder: 20},
+ {ID: g2.ID, SortOrder: 15}, // 重复 ID 应以最后一次为准
+ })
+ s.Require().NoError(err)
+
+ got1, err := s.repo.GetByID(s.ctx, g1.ID)
+ s.Require().NoError(err)
+ got2, err := s.repo.GetByID(s.ctx, g2.ID)
+ s.Require().NoError(err)
+ got3, err := s.repo.GetByID(s.ctx, g3.ID)
+ s.Require().NoError(err)
+ s.Require().Equal(30, got1.SortOrder)
+ s.Require().Equal(15, got2.SortOrder)
+ s.Require().Equal(20, got3.SortOrder)
+}
+
+func (s *GroupRepoSuite) TestUpdateSortOrders_MissingGroupNoPartialUpdate() {
+ g1 := &service.Group{
+ Name: "sort-no-partial",
+ Platform: service.PlatformAnthropic,
+ RateMultiplier: 1.0,
+ IsExclusive: false,
+ Status: service.StatusActive,
+ SubscriptionType: service.SubscriptionTypeStandard,
+ }
+ s.Require().NoError(s.repo.Create(s.ctx, g1))
+
+ before, err := s.repo.GetByID(s.ctx, g1.ID)
+ s.Require().NoError(err)
+ beforeSort := before.SortOrder
+
+ err = s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{
+ {ID: g1.ID, SortOrder: 99},
+ {ID: 99999999, SortOrder: 1},
+ })
+ s.Require().Error(err)
+ s.Require().ErrorIs(err, service.ErrGroupNotFound)
+
+ after, err := s.repo.GetByID(s.ctx, g1.ID)
+ s.Require().NoError(err)
+ s.Require().Equal(beforeSort, after.SortOrder)
+}
+
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
g1 := &service.Group{
Name: "g1",
diff --git a/backend/internal/repository/idempotency_repo_integration_test.go b/backend/internal/repository/idempotency_repo_integration_test.go
index 23b52726..f163c2f0 100644
--- a/backend/internal/repository/idempotency_repo_integration_test.go
+++ b/backend/internal/repository/idempotency_repo_integration_test.go
@@ -147,4 +147,3 @@ func TestIdempotencyRepo_StatusTransition_ToSucceeded(t *testing.T) {
require.Equal(t, `{"ok":true}`, *got.ResponseBody)
require.Nil(t, got.LockedUntil)
}
-
diff --git a/backend/internal/repository/identity_cache.go b/backend/internal/repository/identity_cache.go
index c4986547..6152dd7a 100644
--- a/backend/internal/repository/identity_cache.go
+++ b/backend/internal/repository/identity_cache.go
@@ -12,7 +12,7 @@ import (
const (
fingerprintKeyPrefix = "fingerprint:"
- fingerprintTTL = 24 * time.Hour
+ fingerprintTTL = 7 * 24 * time.Hour // 7天,配合每24小时懒续期可保持活跃账号永不过期
maskedSessionKeyPrefix = "masked_session:"
maskedSessionTTL = 15 * time.Minute
)
diff --git a/backend/internal/repository/migrations_runner.go b/backend/internal/repository/migrations_runner.go
index 5912e50f..a60ba294 100644
--- a/backend/internal/repository/migrations_runner.go
+++ b/backend/internal/repository/migrations_runner.go
@@ -50,6 +50,23 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
// 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。
const migrationsAdvisoryLockID int64 = 694208311321144027
const migrationsLockRetryInterval = 500 * time.Millisecond
+const nonTransactionalMigrationSuffix = "_notx.sql"
+
+type migrationChecksumCompatibilityRule struct {
+ fileChecksum string
+ acceptedDBChecksum map[string]struct{}
+}
+
+// migrationChecksumCompatibilityRules 仅用于兼容历史上误修改过的迁移文件 checksum。
+// 规则必须同时匹配「迁移名 + 当前文件 checksum + 历史库 checksum」才会放行,避免放宽全局校验。
+var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibilityRule{
+ "054_drop_legacy_cache_columns.sql": {
+ fileChecksum: "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
+ acceptedDBChecksum: map[string]struct{}{
+ "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {},
+ },
+ },
+}
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
//
@@ -147,6 +164,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
if rowErr == nil {
// 迁移已应用,验证校验和是否匹配
if existing != checksum {
+ // 兼容特定历史误改场景(仅白名单规则),其余仍保持严格不可变约束。
+ if isMigrationChecksumCompatible(name, existing, checksum) {
+ continue
+ }
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
// 正确的做法是创建新的迁移文件来进行变更。
return fmt.Errorf(
@@ -165,8 +186,34 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
return fmt.Errorf("check migration %s: %w", name, rowErr)
}
- // 迁移未应用,在事务中执行。
- // 使用事务确保迁移的原子性:要么完全成功,要么完全回滚。
+ nonTx, err := validateMigrationExecutionMode(name, content)
+ if err != nil {
+ return fmt.Errorf("validate migration %s: %w", name, err)
+ }
+
+ if nonTx {
+ // *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。
+ // 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。
+ statements := splitSQLStatements(content)
+ for i, stmt := range statements {
+ trimmed := strings.TrimSpace(stmt)
+ if trimmed == "" {
+ continue
+ }
+ if stripSQLLineComment(trimmed) == "" {
+ continue
+ }
+ if _, err := db.ExecContext(ctx, trimmed); err != nil {
+ return fmt.Errorf("apply migration %s (non-tx statement %d): %w", name, i+1, err)
+ }
+ }
+ if _, err := db.ExecContext(ctx, "INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)", name, checksum); err != nil {
+ return fmt.Errorf("record migration %s (non-tx): %w", name, err)
+ }
+ continue
+ }
+
+ // 默认迁移在事务中执行,确保原子性:要么完全成功,要么完全回滚。
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin migration %s: %w", name, err)
@@ -268,6 +315,84 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
return version, version, hash, nil
}
+func isMigrationChecksumCompatible(name, dbChecksum, fileChecksum string) bool {
+ rule, ok := migrationChecksumCompatibilityRules[name]
+ if !ok {
+ return false
+ }
+ if rule.fileChecksum != fileChecksum {
+ return false
+ }
+ _, ok = rule.acceptedDBChecksum[dbChecksum]
+ return ok
+}
+
+func validateMigrationExecutionMode(name, content string) (bool, error) {
+ normalizedName := strings.ToLower(strings.TrimSpace(name))
+ upperContent := strings.ToUpper(content)
+ nonTx := strings.HasSuffix(normalizedName, nonTransactionalMigrationSuffix)
+
+ if !nonTx {
+ if strings.Contains(upperContent, "CONCURRENTLY") {
+ return false, errors.New("CONCURRENTLY statements must be placed in *_notx.sql migrations")
+ }
+ return false, nil
+ }
+
+ if strings.Contains(upperContent, "BEGIN") || strings.Contains(upperContent, "COMMIT") || strings.Contains(upperContent, "ROLLBACK") {
+ return false, errors.New("*_notx.sql must not contain transaction control statements (BEGIN/COMMIT/ROLLBACK)")
+ }
+
+ statements := splitSQLStatements(content)
+ for _, stmt := range statements {
+ normalizedStmt := strings.ToUpper(stripSQLLineComment(strings.TrimSpace(stmt)))
+ if normalizedStmt == "" {
+ continue
+ }
+
+ if strings.Contains(normalizedStmt, "CONCURRENTLY") {
+ isCreateIndex := strings.Contains(normalizedStmt, "CREATE") && strings.Contains(normalizedStmt, "INDEX")
+ isDropIndex := strings.Contains(normalizedStmt, "DROP") && strings.Contains(normalizedStmt, "INDEX")
+ if !isCreateIndex && !isDropIndex {
+ return false, errors.New("*_notx.sql currently only supports CREATE/DROP INDEX CONCURRENTLY statements")
+ }
+ if isCreateIndex && !strings.Contains(normalizedStmt, "IF NOT EXISTS") {
+ return false, errors.New("CREATE INDEX CONCURRENTLY in *_notx.sql must include IF NOT EXISTS for idempotency")
+ }
+ if isDropIndex && !strings.Contains(normalizedStmt, "IF EXISTS") {
+ return false, errors.New("DROP INDEX CONCURRENTLY in *_notx.sql must include IF EXISTS for idempotency")
+ }
+ continue
+ }
+
+ return false, errors.New("*_notx.sql must not mix non-CONCURRENTLY SQL statements")
+ }
+
+ return true, nil
+}
+
+func splitSQLStatements(content string) []string {
+ parts := strings.Split(content, ";")
+ out := make([]string, 0, len(parts))
+ for _, part := range parts {
+ if strings.TrimSpace(part) == "" {
+ continue
+ }
+ out = append(out, part)
+ }
+ return out
+}
+
+func stripSQLLineComment(s string) string {
+ lines := strings.Split(s, "\n")
+ for i, line := range lines {
+ if idx := strings.Index(line, "--"); idx >= 0 {
+ lines[i] = line[:idx]
+ }
+ }
+ return strings.TrimSpace(strings.Join(lines, "\n"))
+}
+
// pgAdvisoryLock 获取 PostgreSQL Advisory Lock。
// Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。
// 它非常适合用于应用层面的分布式锁场景,如迁移序列化。
diff --git a/backend/internal/repository/migrations_runner_checksum_test.go b/backend/internal/repository/migrations_runner_checksum_test.go
new file mode 100644
index 00000000..54f5b0ec
--- /dev/null
+++ b/backend/internal/repository/migrations_runner_checksum_test.go
@@ -0,0 +1,36 @@
+package repository
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestIsMigrationChecksumCompatible(t *testing.T) {
+ t.Run("054历史checksum可兼容", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "054_drop_legacy_cache_columns.sql",
+ "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
+ "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("054在未知文件checksum下不兼容", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "054_drop_legacy_cache_columns.sql",
+ "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
+ "0000000000000000000000000000000000000000000000000000000000000000",
+ )
+ require.False(t, ok)
+ })
+
+ t.Run("非白名单迁移不兼容", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "001_init.sql",
+ "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
+ "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
+ )
+ require.False(t, ok)
+ })
+}
diff --git a/backend/internal/repository/migrations_runner_extra_test.go b/backend/internal/repository/migrations_runner_extra_test.go
new file mode 100644
index 00000000..9f8a94c6
--- /dev/null
+++ b/backend/internal/repository/migrations_runner_extra_test.go
@@ -0,0 +1,368 @@
+package repository
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/hex"
+ "errors"
+ "io/fs"
+ "strings"
+ "testing"
+ "testing/fstest"
+ "time"
+
+ sqlmock "github.com/DATA-DOG/go-sqlmock"
+ "github.com/stretchr/testify/require"
+)
+
+func TestApplyMigrations_NilDB(t *testing.T) {
+ err := ApplyMigrations(context.Background(), nil)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "nil sql db")
+}
+
+func TestApplyMigrations_DelegatesToApplyMigrationsFS(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
+ WithArgs(migrationsAdvisoryLockID).
+ WillReturnError(errors.New("lock failed"))
+
+ err = ApplyMigrations(context.Background(), db)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "acquire migrations lock")
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestLatestMigrationBaseline(t *testing.T) {
+ t.Run("empty_fs_returns_baseline", func(t *testing.T) {
+ version, description, hash, err := latestMigrationBaseline(fstest.MapFS{})
+ require.NoError(t, err)
+ require.Equal(t, "baseline", version)
+ require.Equal(t, "baseline", description)
+ require.Equal(t, "", hash)
+ })
+
+ t.Run("uses_latest_sorted_sql_file", func(t *testing.T) {
+ fsys := fstest.MapFS{
+ "001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")},
+ "010_final.sql": &fstest.MapFile{
+ Data: []byte("CREATE TABLE t2(id int);"),
+ },
+ }
+ version, description, hash, err := latestMigrationBaseline(fsys)
+ require.NoError(t, err)
+ require.Equal(t, "010_final", version)
+ require.Equal(t, "010_final", description)
+ require.Len(t, hash, 64)
+ })
+
+ t.Run("read_file_error", func(t *testing.T) {
+ fsys := fstest.MapFS{
+ "010_bad.sql": &fstest.MapFile{Mode: fs.ModeDir},
+ }
+ _, _, _, err := latestMigrationBaseline(fsys)
+ require.Error(t, err)
+ })
+}
+
+func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) {
+ require.False(t, isMigrationChecksumCompatible("unknown.sql", "db", "file"))
+
+ var (
+ name string
+ rule migrationChecksumCompatibilityRule
+ )
+ for n, r := range migrationChecksumCompatibilityRules {
+ name = n
+ rule = r
+ break
+ }
+ require.NotEmpty(t, name)
+
+ require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", "file-not-match"))
+ require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", rule.fileChecksum))
+
+ var accepted string
+ for checksum := range rule.acceptedDBChecksum {
+ accepted = checksum
+ break
+ }
+ require.NotEmpty(t, accepted)
+ require.True(t, isMigrationChecksumCompatible(name, accepted, rule.fileChecksum))
+}
+
+func TestEnsureAtlasBaselineAligned(t *testing.T) {
+ t.Run("skip_when_no_legacy_table", func(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ mock.ExpectQuery("SELECT EXISTS \\(").
+ WithArgs("schema_migrations").
+ WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
+
+ err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
+ require.NoError(t, err)
+ require.NoError(t, mock.ExpectationsWereMet())
+ })
+
+ t.Run("create_atlas_and_insert_baseline_when_empty", func(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ mock.ExpectQuery("SELECT EXISTS \\(").
+ WithArgs("schema_migrations").
+ WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
+ mock.ExpectQuery("SELECT EXISTS \\(").
+ WithArgs("atlas_schema_revisions").
+ WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
+ mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions").
+ WillReturnResult(sqlmock.NewResult(0, 0))
+ mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
+ WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
+ mock.ExpectExec("INSERT INTO atlas_schema_revisions").
+ WithArgs("002_next", "002_next", 1, sqlmock.AnyArg()).
+ WillReturnResult(sqlmock.NewResult(1, 1))
+
+ fsys := fstest.MapFS{
+ "001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")},
+ "002_next.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t2(id int);")},
+ }
+ err = ensureAtlasBaselineAligned(context.Background(), db, fsys)
+ require.NoError(t, err)
+ require.NoError(t, mock.ExpectationsWereMet())
+ })
+
+ t.Run("error_when_checking_legacy_table", func(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ mock.ExpectQuery("SELECT EXISTS \\(").
+ WithArgs("schema_migrations").
+ WillReturnError(errors.New("exists failed"))
+
+ err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "check schema_migrations")
+ require.NoError(t, mock.ExpectationsWereMet())
+ })
+
+ t.Run("error_when_counting_atlas_rows", func(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ mock.ExpectQuery("SELECT EXISTS \\(").
+ WithArgs("schema_migrations").
+ WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
+ mock.ExpectQuery("SELECT EXISTS \\(").
+ WithArgs("atlas_schema_revisions").
+ WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
+ mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
+ WillReturnError(errors.New("count failed"))
+
+ err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "count atlas_schema_revisions")
+ require.NoError(t, mock.ExpectationsWereMet())
+ })
+
+ t.Run("error_when_creating_atlas_table", func(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ mock.ExpectQuery("SELECT EXISTS \\(").
+ WithArgs("schema_migrations").
+ WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
+ mock.ExpectQuery("SELECT EXISTS \\(").
+ WithArgs("atlas_schema_revisions").
+ WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
+ mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions").
+ WillReturnError(errors.New("create failed"))
+
+ err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "create atlas_schema_revisions")
+ require.NoError(t, mock.ExpectationsWereMet())
+ })
+
+ t.Run("error_when_inserting_baseline", func(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ mock.ExpectQuery("SELECT EXISTS \\(").
+ WithArgs("schema_migrations").
+ WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
+ mock.ExpectQuery("SELECT EXISTS \\(").
+ WithArgs("atlas_schema_revisions").
+ WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
+ mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
+ WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
+ mock.ExpectExec("INSERT INTO atlas_schema_revisions").
+ WithArgs("001_init", "001_init", 1, sqlmock.AnyArg()).
+ WillReturnError(errors.New("insert failed"))
+
+ fsys := fstest.MapFS{
+ "001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")},
+ }
+ err = ensureAtlasBaselineAligned(context.Background(), db, fsys)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "insert atlas baseline")
+ require.NoError(t, mock.ExpectationsWereMet())
+ })
+}
+
+func TestApplyMigrationsFS_ChecksumMismatchRejected(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ prepareMigrationsBootstrapExpectations(mock)
+ mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
+ WithArgs("001_init.sql").
+ WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow("mismatched-checksum"))
+ mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
+ WithArgs(migrationsAdvisoryLockID).
+ WillReturnResult(sqlmock.NewResult(0, 1))
+
+ fsys := fstest.MapFS{
+ "001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")},
+ }
+ err = applyMigrationsFS(context.Background(), db, fsys)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "checksum mismatch")
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestApplyMigrationsFS_CheckMigrationQueryError(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ prepareMigrationsBootstrapExpectations(mock)
+ mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
+ WithArgs("001_err.sql").
+ WillReturnError(errors.New("query failed"))
+ mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
+ WithArgs(migrationsAdvisoryLockID).
+ WillReturnResult(sqlmock.NewResult(0, 1))
+
+ fsys := fstest.MapFS{
+ "001_err.sql": &fstest.MapFile{Data: []byte("SELECT 1;")},
+ }
+ err = applyMigrationsFS(context.Background(), db, fsys)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "check migration 001_err.sql")
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestApplyMigrationsFS_SkipEmptyAndAlreadyApplied(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ prepareMigrationsBootstrapExpectations(mock)
+
+ alreadySQL := "CREATE TABLE t(id int);"
+ checksum := migrationChecksum(alreadySQL)
+ mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
+ WithArgs("001_already.sql").
+ WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow(checksum))
+ mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
+ WithArgs(migrationsAdvisoryLockID).
+ WillReturnResult(sqlmock.NewResult(0, 1))
+
+ fsys := fstest.MapFS{
+ "000_empty.sql": &fstest.MapFile{Data: []byte(" \n\t ")},
+ "001_already.sql": &fstest.MapFile{Data: []byte(alreadySQL)},
+ }
+ err = applyMigrationsFS(context.Background(), db, fsys)
+ require.NoError(t, err)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestApplyMigrationsFS_ReadMigrationError(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ prepareMigrationsBootstrapExpectations(mock)
+ mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
+ WithArgs(migrationsAdvisoryLockID).
+ WillReturnResult(sqlmock.NewResult(0, 1))
+
+ fsys := fstest.MapFS{
+ "001_bad.sql": &fstest.MapFile{Mode: fs.ModeDir},
+ }
+ err = applyMigrationsFS(context.Background(), db, fsys)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "read migration 001_bad.sql")
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestPgAdvisoryLockAndUnlock_ErrorBranches(t *testing.T) {
+ t.Run("context_cancelled_while_not_locked", func(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
+ WithArgs(migrationsAdvisoryLockID).
+ WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false))
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
+ defer cancel()
+ err = pgAdvisoryLock(ctx, db)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "acquire migrations lock")
+ require.NoError(t, mock.ExpectationsWereMet())
+ })
+
+ t.Run("unlock_exec_error", func(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
+ WithArgs(migrationsAdvisoryLockID).
+ WillReturnError(errors.New("unlock failed"))
+
+ err = pgAdvisoryUnlock(context.Background(), db)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "release migrations lock")
+ require.NoError(t, mock.ExpectationsWereMet())
+ })
+
+ t.Run("acquire_lock_after_retry", func(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
+ WithArgs(migrationsAdvisoryLockID).
+ WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false))
+ mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
+ WithArgs(migrationsAdvisoryLockID).
+ WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true))
+
+ ctx, cancel := context.WithTimeout(context.Background(), migrationsLockRetryInterval*3)
+ defer cancel()
+ start := time.Now()
+ err = pgAdvisoryLock(ctx, db)
+ require.NoError(t, err)
+ require.GreaterOrEqual(t, time.Since(start), migrationsLockRetryInterval)
+ require.NoError(t, mock.ExpectationsWereMet())
+ })
+}
+
+func migrationChecksum(content string) string {
+ sum := sha256.Sum256([]byte(strings.TrimSpace(content)))
+ return hex.EncodeToString(sum[:])
+}
diff --git a/backend/internal/repository/migrations_runner_notx_test.go b/backend/internal/repository/migrations_runner_notx_test.go
new file mode 100644
index 00000000..db1183cd
--- /dev/null
+++ b/backend/internal/repository/migrations_runner_notx_test.go
@@ -0,0 +1,164 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+ "testing/fstest"
+
+ sqlmock "github.com/DATA-DOG/go-sqlmock"
+ "github.com/stretchr/testify/require"
+)
+
+func TestValidateMigrationExecutionMode(t *testing.T) {
+ t.Run("事务迁移包含CONCURRENTLY会被拒绝", func(t *testing.T) {
+ nonTx, err := validateMigrationExecutionMode("001_add_idx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);")
+ require.False(t, nonTx)
+ require.Error(t, err)
+ })
+
+ t.Run("notx迁移要求CREATE使用IF NOT EXISTS", func(t *testing.T) {
+ nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);")
+ require.False(t, nonTx)
+ require.Error(t, err)
+ })
+
+ t.Run("notx迁移要求DROP使用IF EXISTS", func(t *testing.T) {
+ nonTx, err := validateMigrationExecutionMode("001_drop_idx_notx.sql", "DROP INDEX CONCURRENTLY idx_a;")
+ require.False(t, nonTx)
+ require.Error(t, err)
+ })
+
+ t.Run("notx迁移禁止事务控制语句", func(t *testing.T) {
+ nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "BEGIN; CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); COMMIT;")
+ require.False(t, nonTx)
+ require.Error(t, err)
+ })
+
+ t.Run("notx迁移禁止混用非CONCURRENTLY语句", func(t *testing.T) {
+ nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); UPDATE t SET a = 1;")
+ require.False(t, nonTx)
+ require.Error(t, err)
+ })
+
+ t.Run("notx迁移允许幂等并发索引语句", func(t *testing.T) {
+ nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", `
+CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a);
+DROP INDEX CONCURRENTLY IF EXISTS idx_b;
+`)
+ require.True(t, nonTx)
+ require.NoError(t, err)
+ })
+}
+
+func TestApplyMigrationsFS_NonTransactionalMigration(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ prepareMigrationsBootstrapExpectations(mock)
+ mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
+ WithArgs("001_add_idx_notx.sql").
+ WillReturnError(sql.ErrNoRows)
+ mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)").
+ WillReturnResult(sqlmock.NewResult(0, 0))
+ mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
+ WithArgs("001_add_idx_notx.sql", sqlmock.AnyArg()).
+ WillReturnResult(sqlmock.NewResult(1, 1))
+ mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
+ WithArgs(migrationsAdvisoryLockID).
+ WillReturnResult(sqlmock.NewResult(0, 1))
+
+ fsys := fstest.MapFS{
+ "001_add_idx_notx.sql": &fstest.MapFile{
+ Data: []byte("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a);"),
+ },
+ }
+
+ err = applyMigrationsFS(context.Background(), db, fsys)
+ require.NoError(t, err)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestApplyMigrationsFS_NonTransactionalMigration_MultiStatements(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ prepareMigrationsBootstrapExpectations(mock)
+ mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
+ WithArgs("001_add_multi_idx_notx.sql").
+ WillReturnError(sql.ErrNoRows)
+ mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)").
+ WillReturnResult(sqlmock.NewResult(0, 0))
+ mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t\\(b\\)").
+ WillReturnResult(sqlmock.NewResult(0, 0))
+ mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
+ WithArgs("001_add_multi_idx_notx.sql", sqlmock.AnyArg()).
+ WillReturnResult(sqlmock.NewResult(1, 1))
+ mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
+ WithArgs(migrationsAdvisoryLockID).
+ WillReturnResult(sqlmock.NewResult(0, 1))
+
+ fsys := fstest.MapFS{
+ "001_add_multi_idx_notx.sql": &fstest.MapFile{
+ Data: []byte(`
+-- first
+CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a);
+-- second
+CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b);
+`),
+ },
+ }
+
+ err = applyMigrationsFS(context.Background(), db, fsys)
+ require.NoError(t, err)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ prepareMigrationsBootstrapExpectations(mock)
+ mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
+ WithArgs("001_add_col.sql").
+ WillReturnError(sql.ErrNoRows)
+ mock.ExpectBegin()
+ mock.ExpectExec("ALTER TABLE t ADD COLUMN name TEXT").
+ WillReturnResult(sqlmock.NewResult(0, 0))
+ mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
+ WithArgs("001_add_col.sql", sqlmock.AnyArg()).
+ WillReturnResult(sqlmock.NewResult(1, 1))
+ mock.ExpectCommit()
+ mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
+ WithArgs(migrationsAdvisoryLockID).
+ WillReturnResult(sqlmock.NewResult(0, 1))
+
+ fsys := fstest.MapFS{
+ "001_add_col.sql": &fstest.MapFile{
+ Data: []byte("ALTER TABLE t ADD COLUMN name TEXT;"),
+ },
+ }
+
+ err = applyMigrationsFS(context.Background(), db, fsys)
+ require.NoError(t, err)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func prepareMigrationsBootstrapExpectations(mock sqlmock.Sqlmock) {
+ mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
+ WithArgs(migrationsAdvisoryLockID).
+ WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true))
+ mock.ExpectExec("CREATE TABLE IF NOT EXISTS schema_migrations").
+ WillReturnResult(sqlmock.NewResult(0, 0))
+ mock.ExpectQuery("SELECT EXISTS \\(").
+ WithArgs("schema_migrations").
+ WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
+ mock.ExpectQuery("SELECT EXISTS \\(").
+ WithArgs("atlas_schema_revisions").
+ WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
+ mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
+ WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
+}
diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go
index f50d2b26..72422d18 100644
--- a/backend/internal/repository/migrations_schema_integration_test.go
+++ b/backend/internal/repository/migrations_schema_integration_test.go
@@ -42,6 +42,8 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
// usage_logs: billing_type used by filters/stats
requireColumn(t, tx, "usage_logs", "billing_type", "smallint", 0, false)
+ requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
+ requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
// settings table should exist
var settingsRegclass sql.NullString
diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go
index 088e7d7f..3e155971 100644
--- a/backend/internal/repository/openai_oauth_service.go
+++ b/backend/internal/repository/openai_oauth_service.go
@@ -22,16 +22,20 @@ type openaiOAuthService struct {
tokenURL string
}
-func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
+func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
}
+ clientID = strings.TrimSpace(clientID)
+ if clientID == "" {
+ clientID = openai.ClientID
+ }
formData := url.Values{}
formData.Set("grant_type", "authorization_code")
- formData.Set("client_id", openai.ClientID)
+ formData.Set("client_id", clientID)
formData.Set("code", code)
formData.Set("redirect_uri", redirectURI)
formData.Set("code_verifier", codeVerifier)
@@ -61,36 +65,12 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
}
func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
- if strings.TrimSpace(clientID) != "" {
- return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, strings.TrimSpace(clientID))
+ // 调用方应始终传入正确的 client_id;为兼容旧数据,未指定时默认使用 OpenAI ClientID
+ clientID = strings.TrimSpace(clientID)
+ if clientID == "" {
+ clientID = openai.ClientID
}
-
- clientIDs := []string{
- openai.ClientID,
- openai.SoraClientID,
- }
- seen := make(map[string]struct{}, len(clientIDs))
- var lastErr error
- for _, clientID := range clientIDs {
- clientID = strings.TrimSpace(clientID)
- if clientID == "" {
- continue
- }
- if _, ok := seen[clientID]; ok {
- continue
- }
- seen[clientID] = struct{}{}
-
- tokenResp, err := s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
- if err == nil {
- return tokenResp, nil
- }
- lastErr = err
- }
- if lastErr != nil {
- return nil, lastErr
- }
- return nil, infraerrors.New(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed")
+ return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
}
func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {
diff --git a/backend/internal/repository/openai_oauth_service_test.go b/backend/internal/repository/openai_oauth_service_test.go
index 5938272a..44fa291b 100644
--- a/backend/internal/repository/openai_oauth_service_test.go
+++ b/backend/internal/repository/openai_oauth_service_test.go
@@ -81,7 +81,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() {
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
}))
- resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "")
+ resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "", "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
@@ -136,7 +136,9 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
require.Equal(s.T(), "rt2", resp.RefreshToken)
}
-func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
+// TestRefreshToken_DefaultsToOpenAIClientID 验证未指定 client_id 时默认使用 OpenAI ClientID,
+// 且只发送一次请求(不再盲猜多个 client_id)。
+func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() {
var seenClientIDs []string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
@@ -145,11 +147,27 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
}
clientID := r.PostForm.Get("client_id")
seenClientIDs = append(seenClientIDs, clientID)
- if clientID == openai.ClientID {
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
+ }))
+
+ resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
+ require.NoError(s.T(), err, "RefreshToken")
+ require.Equal(s.T(), "at", resp.AccessToken)
+ // 只发送了一次请求,使用默认的 OpenAI ClientID
+ require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs)
+}
+
+// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。
+func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() {
+ var seenClientIDs []string
+ s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
- _, _ = io.WriteString(w, "invalid_grant")
return
}
+ clientID := r.PostForm.Get("client_id")
+ seenClientIDs = append(seenClientIDs, clientID)
if clientID == openai.SoraClientID {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
@@ -158,11 +176,10 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
w.WriteHeader(http.StatusBadRequest)
}))
- resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
- require.NoError(s.T(), err, "RefreshToken")
+ resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID)
+ require.NoError(s.T(), err, "RefreshTokenWithClientID")
require.Equal(s.T(), "at-sora", resp.AccessToken)
- require.Equal(s.T(), "rt-sora", resp.RefreshToken)
- require.Equal(s.T(), []string{openai.ClientID, openai.SoraClientID}, seenClientIDs)
+ require.Equal(s.T(), []string{openai.SoraClientID}, seenClientIDs)
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
@@ -196,7 +213,7 @@ func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
_, _ = io.WriteString(w, "bad")
}))
- _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
+ _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status 400")
require.ErrorContains(s.T(), err, "bad")
@@ -206,7 +223,7 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.srv.Close()
- _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
+ _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "request failed")
}
@@ -223,7 +240,7 @@ func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
done := make(chan error, 1)
go func() {
- _, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "")
+ _, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
done <- err
}()
@@ -249,7 +266,30 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
}))
- _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "")
+ _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "", "")
+ require.NoError(s.T(), err, "ExchangeCode")
+ select {
+ case msg := <-errCh:
+ require.Fail(s.T(), msg)
+ default:
+ }
+}
+
+func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() {
+ wantClientID := openai.SoraClientID
+ errCh := make(chan string, 1)
+ s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _ = r.ParseForm()
+ if got := r.PostForm.Get("client_id"); got != wantClientID {
+ errCh <- "client_id mismatch"
+ w.WriteHeader(http.StatusBadRequest)
+ return
+ }
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
+ }))
+
+ _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", wantClientID)
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
@@ -267,7 +307,7 @@ func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() {
}))
s.svc.tokenURL = s.srv.URL + "?x=1"
- _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
+ _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case <-s.received:
@@ -283,7 +323,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() {
_, _ = io.WriteString(w, "not-valid-json")
}))
- _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
+ _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.Error(s.T(), err, "expected error for invalid JSON response")
}
diff --git a/backend/internal/repository/ops_repo_dashboard.go b/backend/internal/repository/ops_repo_dashboard.go
index 85791a9a..b43d6706 100644
--- a/backend/internal/repository/ops_repo_dashboard.go
+++ b/backend/internal/repository/ops_repo_dashboard.go
@@ -12,6 +12,11 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
)
+const (
+ opsRawLatencyQueryTimeout = 2 * time.Second
+ opsRawPeakQueryTimeout = 1500 * time.Millisecond
+)
+
func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
@@ -45,15 +50,24 @@ func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *servic
func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) {
start := filter.StartTime.UTC()
end := filter.EndTime.UTC()
+ degraded := false
successCount, tokenConsumed, err := r.queryUsageCounts(ctx, filter, start, end)
if err != nil {
return nil, err
}
- duration, ttft, err := r.queryUsageLatency(ctx, filter, start, end)
+ latencyCtx, cancelLatency := context.WithTimeout(ctx, opsRawLatencyQueryTimeout)
+ duration, ttft, err := r.queryUsageLatency(latencyCtx, filter, start, end)
+ cancelLatency()
if err != nil {
- return nil, err
+ if isQueryTimeoutErr(err) {
+ degraded = true
+ duration = service.OpsPercentiles{}
+ ttft = service.OpsPercentiles{}
+ } else {
+ return nil, err
+ }
}
errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end)
@@ -75,20 +89,40 @@ func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *ser
qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end)
if err != nil {
- return nil, err
+ if isQueryTimeoutErr(err) {
+ degraded = true
+ } else {
+ return nil, err
+ }
}
- qpsPeak, err := r.queryPeakQPS(ctx, filter, start, end)
+ peakCtx, cancelPeak := context.WithTimeout(ctx, opsRawPeakQueryTimeout)
+ qpsPeak, tpsPeak, err := r.queryPeakRates(peakCtx, filter, start, end)
+ cancelPeak()
if err != nil {
- return nil, err
- }
- tpsPeak, err := r.queryPeakTPS(ctx, filter, start, end)
- if err != nil {
- return nil, err
+ if isQueryTimeoutErr(err) {
+ degraded = true
+ } else {
+ return nil, err
+ }
}
qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds)
tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds)
+ if degraded {
+ if qpsCurrent <= 0 {
+ qpsCurrent = qpsAvg
+ }
+ if tpsCurrent <= 0 {
+ tpsCurrent = tpsAvg
+ }
+ if qpsPeak <= 0 {
+ qpsPeak = roundTo1DP(math.Max(qpsCurrent, qpsAvg))
+ }
+ if tpsPeak <= 0 {
+ tpsPeak = roundTo1DP(math.Max(tpsCurrent, tpsAvg))
+ }
+ }
return &service.OpsDashboardOverview{
StartTime: start,
@@ -230,26 +264,45 @@ func (r *opsRepository) getDashboardOverviewPreaggregated(ctx context.Context, f
sla := safeDivideFloat64(float64(successCount), float64(requestCountSLA))
errorRate := safeDivideFloat64(float64(errorCountSLA), float64(requestCountSLA))
upstreamErrorRate := safeDivideFloat64(float64(upstreamExcl), float64(requestCountSLA))
+ degraded := false
// Keep "current" rates as raw, to preserve realtime semantics.
qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end)
if err != nil {
- return nil, err
+ if isQueryTimeoutErr(err) {
+ degraded = true
+ } else {
+ return nil, err
+ }
}
- // NOTE: peak still uses raw logs (minute granularity). This is typically cheaper than percentile_cont
- // and keeps semantics consistent across modes.
- qpsPeak, err := r.queryPeakQPS(ctx, filter, start, end)
+ peakCtx, cancelPeak := context.WithTimeout(ctx, opsRawPeakQueryTimeout)
+ qpsPeak, tpsPeak, err := r.queryPeakRates(peakCtx, filter, start, end)
+ cancelPeak()
if err != nil {
- return nil, err
- }
- tpsPeak, err := r.queryPeakTPS(ctx, filter, start, end)
- if err != nil {
- return nil, err
+ if isQueryTimeoutErr(err) {
+ degraded = true
+ } else {
+ return nil, err
+ }
}
qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds)
tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds)
+ if degraded {
+ if qpsCurrent <= 0 {
+ qpsCurrent = qpsAvg
+ }
+ if tpsCurrent <= 0 {
+ tpsCurrent = tpsAvg
+ }
+ if qpsPeak <= 0 {
+ qpsPeak = roundTo1DP(math.Max(qpsCurrent, qpsAvg))
+ }
+ if tpsPeak <= 0 {
+ tpsPeak = roundTo1DP(math.Max(tpsCurrent, tpsAvg))
+ }
+ }
return &service.OpsDashboardOverview{
StartTime: start,
@@ -577,9 +630,16 @@ func (r *opsRepository) queryRawPartial(ctx context.Context, filter *service.Ops
return nil, err
}
- duration, ttft, err := r.queryUsageLatency(ctx, filter, start, end)
+ latencyCtx, cancelLatency := context.WithTimeout(ctx, opsRawLatencyQueryTimeout)
+ duration, ttft, err := r.queryUsageLatency(latencyCtx, filter, start, end)
+ cancelLatency()
if err != nil {
- return nil, err
+ if isQueryTimeoutErr(err) {
+ duration = service.OpsPercentiles{}
+ ttft = service.OpsPercentiles{}
+ } else {
+ return nil, err
+ }
}
errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end)
@@ -735,68 +795,56 @@ FROM usage_logs ul
}
func (r *opsRepository) queryUsageLatency(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (duration service.OpsPercentiles, ttft service.OpsPercentiles, err error) {
- {
- join, where, args, _ := buildUsageWhere(filter, start, end, 1)
- q := `
+ join, where, args, _ := buildUsageWhere(filter, start, end, 1)
+ q := `
SELECT
- percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) AS p50,
- percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) AS p90,
- percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) AS p95,
- percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) AS p99,
- AVG(duration_ms) AS avg_ms,
- MAX(duration_ms) AS max_ms
+ percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p50,
+ percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p90,
+ percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p95,
+ percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p99,
+ AVG(duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_avg,
+ MAX(duration_ms) AS duration_max,
+ percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p50,
+ percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p90,
+ percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p95,
+ percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p99,
+ AVG(first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_avg,
+ MAX(first_token_ms) AS ttft_max
FROM usage_logs ul
` + join + `
-` + where + `
-AND duration_ms IS NOT NULL`
+` + where
- var p50, p90, p95, p99 sql.NullFloat64
- var avg sql.NullFloat64
- var max sql.NullInt64
- if err := r.db.QueryRowContext(ctx, q, args...).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil {
- return service.OpsPercentiles{}, service.OpsPercentiles{}, err
- }
- duration.P50 = floatToIntPtr(p50)
- duration.P90 = floatToIntPtr(p90)
- duration.P95 = floatToIntPtr(p95)
- duration.P99 = floatToIntPtr(p99)
- duration.Avg = floatToIntPtr(avg)
- if max.Valid {
- v := int(max.Int64)
- duration.Max = &v
- }
+ var dP50, dP90, dP95, dP99 sql.NullFloat64
+ var dAvg sql.NullFloat64
+ var dMax sql.NullInt64
+ var tP50, tP90, tP95, tP99 sql.NullFloat64
+ var tAvg sql.NullFloat64
+ var tMax sql.NullInt64
+ if err := r.db.QueryRowContext(ctx, q, args...).Scan(
+ &dP50, &dP90, &dP95, &dP99, &dAvg, &dMax,
+ &tP50, &tP90, &tP95, &tP99, &tAvg, &tMax,
+ ); err != nil {
+ return service.OpsPercentiles{}, service.OpsPercentiles{}, err
}
- {
- join, where, args, _ := buildUsageWhere(filter, start, end, 1)
- q := `
-SELECT
- percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) AS p50,
- percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) AS p90,
- percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) AS p95,
- percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) AS p99,
- AVG(first_token_ms) AS avg_ms,
- MAX(first_token_ms) AS max_ms
-FROM usage_logs ul
-` + join + `
-` + where + `
-AND first_token_ms IS NOT NULL`
+ duration.P50 = floatToIntPtr(dP50)
+ duration.P90 = floatToIntPtr(dP90)
+ duration.P95 = floatToIntPtr(dP95)
+ duration.P99 = floatToIntPtr(dP99)
+ duration.Avg = floatToIntPtr(dAvg)
+ if dMax.Valid {
+ v := int(dMax.Int64)
+ duration.Max = &v
+ }
- var p50, p90, p95, p99 sql.NullFloat64
- var avg sql.NullFloat64
- var max sql.NullInt64
- if err := r.db.QueryRowContext(ctx, q, args...).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil {
- return service.OpsPercentiles{}, service.OpsPercentiles{}, err
- }
- ttft.P50 = floatToIntPtr(p50)
- ttft.P90 = floatToIntPtr(p90)
- ttft.P95 = floatToIntPtr(p95)
- ttft.P99 = floatToIntPtr(p99)
- ttft.Avg = floatToIntPtr(avg)
- if max.Valid {
- v := int(max.Int64)
- ttft.Max = &v
- }
+ ttft.P50 = floatToIntPtr(tP50)
+ ttft.P90 = floatToIntPtr(tP90)
+ ttft.P95 = floatToIntPtr(tP95)
+ ttft.P99 = floatToIntPtr(tP99)
+ ttft.Avg = floatToIntPtr(tAvg)
+ if tMax.Valid {
+ v := int(tMax.Int64)
+ ttft.Max = &v
}
return duration, ttft, nil
@@ -854,20 +902,23 @@ func (r *opsRepository) queryCurrentRates(ctx context.Context, filter *service.O
return qpsCurrent, tpsCurrent, nil
}
-func (r *opsRepository) queryPeakQPS(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (float64, error) {
+func (r *opsRepository) queryPeakRates(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (qpsPeak float64, tpsPeak float64, err error) {
usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1)
errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next)
q := `
WITH usage_buckets AS (
- SELECT date_trunc('minute', ul.created_at) AS bucket, COUNT(*) AS cnt
+ SELECT
+ date_trunc('minute', ul.created_at) AS bucket,
+ COUNT(*) AS req_cnt,
+ COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_cnt
FROM usage_logs ul
` + usageJoin + `
` + usageWhere + `
GROUP BY 1
),
error_buckets AS (
- SELECT date_trunc('minute', created_at) AS bucket, COUNT(*) AS cnt
+ SELECT date_trunc('minute', created_at) AS bucket, COUNT(*) AS err_cnt
FROM ops_error_logs
` + errorWhere + `
AND COALESCE(status_code, 0) >= 400
@@ -875,47 +926,33 @@ error_buckets AS (
),
combined AS (
SELECT COALESCE(u.bucket, e.bucket) AS bucket,
- COALESCE(u.cnt, 0) + COALESCE(e.cnt, 0) AS total
+ COALESCE(u.req_cnt, 0) + COALESCE(e.err_cnt, 0) AS total_req,
+ COALESCE(u.token_cnt, 0) AS total_tokens
FROM usage_buckets u
FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
)
-SELECT COALESCE(MAX(total), 0) FROM combined`
+SELECT
+ COALESCE(MAX(total_req), 0) AS max_req_per_min,
+ COALESCE(MAX(total_tokens), 0) AS max_tokens_per_min
+FROM combined`
args := append(usageArgs, errorArgs...)
- var maxPerMinute sql.NullInt64
- if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxPerMinute); err != nil {
- return 0, err
+ var maxReqPerMinute, maxTokensPerMinute sql.NullInt64
+ if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxReqPerMinute, &maxTokensPerMinute); err != nil {
+ return 0, 0, err
}
- if !maxPerMinute.Valid || maxPerMinute.Int64 <= 0 {
- return 0, nil
+ if maxReqPerMinute.Valid && maxReqPerMinute.Int64 > 0 {
+ qpsPeak = roundTo1DP(float64(maxReqPerMinute.Int64) / 60.0)
}
- return roundTo1DP(float64(maxPerMinute.Int64) / 60.0), nil
+ if maxTokensPerMinute.Valid && maxTokensPerMinute.Int64 > 0 {
+ tpsPeak = roundTo1DP(float64(maxTokensPerMinute.Int64) / 60.0)
+ }
+ return qpsPeak, tpsPeak, nil
}
-func (r *opsRepository) queryPeakTPS(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (float64, error) {
- join, where, args, _ := buildUsageWhere(filter, start, end, 1)
-
- q := `
-SELECT COALESCE(MAX(tokens_per_min), 0)
-FROM (
- SELECT
- date_trunc('minute', ul.created_at) AS bucket,
- COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS tokens_per_min
- FROM usage_logs ul
- ` + join + `
- ` + where + `
- GROUP BY 1
-) t`
-
- var maxPerMinute sql.NullInt64
- if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxPerMinute); err != nil {
- return 0, err
- }
- if !maxPerMinute.Valid || maxPerMinute.Int64 <= 0 {
- return 0, nil
- }
- return roundTo1DP(float64(maxPerMinute.Int64) / 60.0), nil
+func isQueryTimeoutErr(err error) bool {
+ return errors.Is(err, context.DeadlineExceeded)
}
func buildUsageWhere(filter *service.OpsDashboardFilter, start, end time.Time, startIndex int) (join string, where string, args []any, nextIndex int) {
diff --git a/backend/internal/repository/ops_repo_dashboard_timeout_test.go b/backend/internal/repository/ops_repo_dashboard_timeout_test.go
new file mode 100644
index 00000000..76332ca0
--- /dev/null
+++ b/backend/internal/repository/ops_repo_dashboard_timeout_test.go
@@ -0,0 +1,22 @@
+package repository
+
+import (
+ "context"
+ "fmt"
+ "testing"
+)
+
+func TestIsQueryTimeoutErr(t *testing.T) {
+ if !isQueryTimeoutErr(context.DeadlineExceeded) {
+ t.Fatalf("context.DeadlineExceeded should be treated as query timeout")
+ }
+ if !isQueryTimeoutErr(fmt.Errorf("wrapped: %w", context.DeadlineExceeded)) {
+ t.Fatalf("wrapped context.DeadlineExceeded should be treated as query timeout")
+ }
+ if isQueryTimeoutErr(context.Canceled) {
+ t.Fatalf("context.Canceled should not be treated as query timeout")
+ }
+ if isQueryTimeoutErr(fmt.Errorf("wrapped: %w", context.Canceled)) {
+ t.Fatalf("wrapped context.Canceled should not be treated as query timeout")
+ }
+}
diff --git a/backend/internal/repository/rpm_cache.go b/backend/internal/repository/rpm_cache.go
new file mode 100644
index 00000000..4d73ec4b
--- /dev/null
+++ b/backend/internal/repository/rpm_cache.go
@@ -0,0 +1,141 @@
+package repository
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strconv"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/redis/go-redis/v9"
+)
+
+// RPM 计数器缓存常量定义
+//
+// 设计说明:
+// 使用 Redis 简单计数器跟踪每个账号每分钟的请求数:
+// - Key: rpm:{accountID}:{minuteTimestamp}
+// - Value: 当前分钟内的请求计数
+// - TTL: 120 秒(覆盖当前分钟 + 一定冗余)
+//
+// 使用 TxPipeline(MULTI/EXEC)执行 INCR + EXPIRE,保证原子性且兼容 Redis Cluster。
+// 通过 rdb.Time() 获取服务端时间,避免多实例时钟不同步。
+//
+// 设计决策:
+// - TxPipeline vs Pipeline:Pipeline 仅合并发送但不保证原子,TxPipeline 使用 MULTI/EXEC 事务保证原子执行。
+// - rdb.Time() 单独调用:Pipeline/TxPipeline 中无法引用前一命令的结果,因此 TIME 必须单独调用(2 RTT)。
+// Lua 脚本可以做到 1 RTT,但在 Redis Cluster 中动态拼接 key 存在 CROSSSLOT 风险,选择安全性优先。
+const (
+ // RPM 计数器键前缀
+ // 格式: rpm:{accountID}:{minuteTimestamp}
+ rpmKeyPrefix = "rpm:"
+
+ // RPM 计数器 TTL(120 秒,覆盖当前分钟窗口 + 冗余)
+ rpmKeyTTL = 120 * time.Second
+)
+
+// RPMCacheImpl RPM 计数器缓存 Redis 实现
+type RPMCacheImpl struct {
+ rdb *redis.Client
+}
+
+// NewRPMCache 创建 RPM 计数器缓存
+func NewRPMCache(rdb *redis.Client) service.RPMCache {
+ return &RPMCacheImpl{rdb: rdb}
+}
+
+// currentMinuteKey 获取当前分钟的完整 Redis key
+// 使用 rdb.Time() 获取 Redis 服务端时间,避免多实例时钟偏差
+func (c *RPMCacheImpl) currentMinuteKey(ctx context.Context, accountID int64) (string, error) {
+ serverTime, err := c.rdb.Time(ctx).Result()
+ if err != nil {
+ return "", fmt.Errorf("redis TIME: %w", err)
+ }
+ minuteTS := serverTime.Unix() / 60
+ return fmt.Sprintf("%s%d:%d", rpmKeyPrefix, accountID, minuteTS), nil
+}
+
+// currentMinuteSuffix 获取当前分钟时间戳后缀(供批量操作使用)
+// 使用 rdb.Time() 获取 Redis 服务端时间
+func (c *RPMCacheImpl) currentMinuteSuffix(ctx context.Context) (string, error) {
+ serverTime, err := c.rdb.Time(ctx).Result()
+ if err != nil {
+ return "", fmt.Errorf("redis TIME: %w", err)
+ }
+ minuteTS := serverTime.Unix() / 60
+ return strconv.FormatInt(minuteTS, 10), nil
+}
+
+// IncrementRPM 原子递增并返回当前分钟的计数
+// 使用 TxPipeline (MULTI/EXEC) 执行 INCR + EXPIRE,保证原子性且兼容 Redis Cluster
+func (c *RPMCacheImpl) IncrementRPM(ctx context.Context, accountID int64) (int, error) {
+ key, err := c.currentMinuteKey(ctx, accountID)
+ if err != nil {
+ return 0, fmt.Errorf("rpm increment: %w", err)
+ }
+
+ // 使用 TxPipeline (MULTI/EXEC) 保证 INCR + EXPIRE 原子执行
+ // EXPIRE 幂等,每次都设置不影响正确性
+ pipe := c.rdb.TxPipeline()
+ incrCmd := pipe.Incr(ctx, key)
+ pipe.Expire(ctx, key, rpmKeyTTL)
+
+ if _, err := pipe.Exec(ctx); err != nil {
+ return 0, fmt.Errorf("rpm increment: %w", err)
+ }
+
+ return int(incrCmd.Val()), nil
+}
+
+// GetRPM 获取当前分钟的 RPM 计数
+func (c *RPMCacheImpl) GetRPM(ctx context.Context, accountID int64) (int, error) {
+ key, err := c.currentMinuteKey(ctx, accountID)
+ if err != nil {
+ return 0, fmt.Errorf("rpm get: %w", err)
+ }
+
+ val, err := c.rdb.Get(ctx, key).Int()
+ if errors.Is(err, redis.Nil) {
+ return 0, nil // 当前分钟无记录
+ }
+ if err != nil {
+ return 0, fmt.Errorf("rpm get: %w", err)
+ }
+ return val, nil
+}
+
+// GetRPMBatch 批量获取多个账号的 RPM 计数(使用 Pipeline)
+func (c *RPMCacheImpl) GetRPMBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
+ if len(accountIDs) == 0 {
+ return map[int64]int{}, nil
+ }
+
+ // 获取当前分钟后缀
+ minuteSuffix, err := c.currentMinuteSuffix(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("rpm batch get: %w", err)
+ }
+
+ // 使用 Pipeline 批量 GET
+ pipe := c.rdb.Pipeline()
+ cmds := make(map[int64]*redis.StringCmd, len(accountIDs))
+ for _, id := range accountIDs {
+ key := fmt.Sprintf("%s%d:%s", rpmKeyPrefix, id, minuteSuffix)
+ cmds[id] = pipe.Get(ctx, key)
+ }
+
+ if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
+ return nil, fmt.Errorf("rpm batch get: %w", err)
+ }
+
+ result := make(map[int64]int, len(accountIDs))
+ for id, cmd := range cmds {
+ if val, err := cmd.Int(); err == nil {
+ result[id] = val
+ } else {
+ result[id] = 0
+ }
+ }
+ return result, nil
+}
diff --git a/backend/internal/repository/sora_generation_repo.go b/backend/internal/repository/sora_generation_repo.go
new file mode 100644
index 00000000..aaf3cb2f
--- /dev/null
+++ b/backend/internal/repository/sora_generation_repo.go
@@ -0,0 +1,419 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。
+// 使用原生 SQL 操作 sora_generations 表。
+type soraGenerationRepository struct {
+ sql *sql.DB
+}
+
+// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。
+func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository {
+ return &soraGenerationRepository{sql: sqlDB}
+}
+
+func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error {
+ mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
+ s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
+
+ err := r.sql.QueryRowContext(ctx, `
+ INSERT INTO sora_generations (
+ user_id, api_key_id, model, prompt, media_type,
+ status, media_url, media_urls, file_size_bytes,
+ storage_type, s3_object_keys, upstream_task_id, error_message
+ ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
+ RETURNING id, created_at
+ `,
+ gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
+ gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
+ gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
+ ).Scan(&gen.ID, &gen.CreatedAt)
+ return err
+}
+
+// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。
+func (r *soraGenerationRepository) CreatePendingWithLimit(
+ ctx context.Context,
+ gen *service.SoraGeneration,
+ activeStatuses []string,
+ maxActive int64,
+) error {
+ if gen == nil {
+ return fmt.Errorf("generation is nil")
+ }
+ if maxActive <= 0 {
+ return r.Create(ctx, gen)
+ }
+ if len(activeStatuses) == 0 {
+ activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating}
+ }
+
+ tx, err := r.sql.BeginTx(ctx, nil)
+ if err != nil {
+ return err
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ // 使用用户级 advisory lock 串行化并发创建,避免超限竞态。
+ if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil {
+ return err
+ }
+
+ placeholders := make([]string, len(activeStatuses))
+ args := make([]any, 0, 1+len(activeStatuses))
+ args = append(args, gen.UserID)
+ for i, s := range activeStatuses {
+ placeholders[i] = fmt.Sprintf("$%d", i+2)
+ args = append(args, s)
+ }
+ countQuery := fmt.Sprintf(
+ `SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`,
+ strings.Join(placeholders, ","),
+ )
+ var activeCount int64
+ if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil {
+ return err
+ }
+ if activeCount >= maxActive {
+ return service.ErrSoraGenerationConcurrencyLimit
+ }
+
+ mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
+ s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
+ if err := tx.QueryRowContext(ctx, `
+ INSERT INTO sora_generations (
+ user_id, api_key_id, model, prompt, media_type,
+ status, media_url, media_urls, file_size_bytes,
+ storage_type, s3_object_keys, upstream_task_id, error_message
+ ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
+ RETURNING id, created_at
+ `,
+ gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
+ gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
+ gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
+ ).Scan(&gen.ID, &gen.CreatedAt); err != nil {
+ return err
+ }
+
+ return tx.Commit()
+}
+
+func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) {
+ gen := &service.SoraGeneration{}
+ var mediaURLsJSON, s3KeysJSON []byte
+ var completedAt sql.NullTime
+ var apiKeyID sql.NullInt64
+
+ err := r.sql.QueryRowContext(ctx, `
+ SELECT id, user_id, api_key_id, model, prompt, media_type,
+ status, media_url, media_urls, file_size_bytes,
+ storage_type, s3_object_keys, upstream_task_id, error_message,
+ created_at, completed_at
+ FROM sora_generations WHERE id = $1
+ `, id).Scan(
+ &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
+ &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
+ &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
+ &gen.CreatedAt, &completedAt,
+ )
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, fmt.Errorf("生成记录不存在")
+ }
+ return nil, err
+ }
+
+ if apiKeyID.Valid {
+ gen.APIKeyID = &apiKeyID.Int64
+ }
+ if completedAt.Valid {
+ gen.CompletedAt = &completedAt.Time
+ }
+ _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
+ _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
+ return gen, nil
+}
+
+func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error {
+ mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
+ s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
+
+ var completedAt *time.Time
+ if gen.CompletedAt != nil {
+ completedAt = gen.CompletedAt
+ }
+
+ _, err := r.sql.ExecContext(ctx, `
+ UPDATE sora_generations SET
+ status = $2, media_url = $3, media_urls = $4, file_size_bytes = $5,
+ storage_type = $6, s3_object_keys = $7, upstream_task_id = $8,
+ error_message = $9, completed_at = $10
+ WHERE id = $1
+ `,
+ gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
+ gen.StorageType, s3KeysJSON, gen.UpstreamTaskID,
+ gen.ErrorMessage, completedAt,
+ )
+ return err
+}
+
+// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。
+func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) {
+ result, err := r.sql.ExecContext(ctx, `
+ UPDATE sora_generations
+ SET status = $2, upstream_task_id = $3
+ WHERE id = $1 AND status = $4
+ `,
+ id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending,
+ )
+ if err != nil {
+ return false, err
+ }
+ affected, err := result.RowsAffected()
+ if err != nil {
+ return false, err
+ }
+ return affected > 0, nil
+}
+
+// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。
+func (r *soraGenerationRepository) UpdateCompletedIfActive(
+ ctx context.Context,
+ id int64,
+ mediaURL string,
+ mediaURLs []string,
+ storageType string,
+ s3Keys []string,
+ fileSizeBytes int64,
+ completedAt time.Time,
+) (bool, error) {
+ mediaURLsJSON, _ := json.Marshal(mediaURLs)
+ s3KeysJSON, _ := json.Marshal(s3Keys)
+ result, err := r.sql.ExecContext(ctx, `
+ UPDATE sora_generations
+ SET status = $2,
+ media_url = $3,
+ media_urls = $4,
+ file_size_bytes = $5,
+ storage_type = $6,
+ s3_object_keys = $7,
+ error_message = '',
+ completed_at = $8
+ WHERE id = $1 AND status IN ($9, $10)
+ `,
+ id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes,
+ storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
+ )
+ if err != nil {
+ return false, err
+ }
+ affected, err := result.RowsAffected()
+ if err != nil {
+ return false, err
+ }
+ return affected > 0, nil
+}
+
+// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。
+func (r *soraGenerationRepository) UpdateFailedIfActive(
+ ctx context.Context,
+ id int64,
+ errMsg string,
+ completedAt time.Time,
+) (bool, error) {
+ result, err := r.sql.ExecContext(ctx, `
+ UPDATE sora_generations
+ SET status = $2,
+ error_message = $3,
+ completed_at = $4
+ WHERE id = $1 AND status IN ($5, $6)
+ `,
+ id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
+ )
+ if err != nil {
+ return false, err
+ }
+ affected, err := result.RowsAffected()
+ if err != nil {
+ return false, err
+ }
+ return affected > 0, nil
+}
+
+// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。
+func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) {
+ result, err := r.sql.ExecContext(ctx, `
+ UPDATE sora_generations
+ SET status = $2, completed_at = $3
+ WHERE id = $1 AND status IN ($4, $5)
+ `,
+ id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
+ )
+ if err != nil {
+ return false, err
+ }
+ affected, err := result.RowsAffected()
+ if err != nil {
+ return false, err
+ }
+ return affected > 0, nil
+}
+
+// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at)。
+func (r *soraGenerationRepository) UpdateStorageIfCompleted(
+ ctx context.Context,
+ id int64,
+ mediaURL string,
+ mediaURLs []string,
+ storageType string,
+ s3Keys []string,
+ fileSizeBytes int64,
+) (bool, error) {
+ mediaURLsJSON, _ := json.Marshal(mediaURLs)
+ s3KeysJSON, _ := json.Marshal(s3Keys)
+ result, err := r.sql.ExecContext(ctx, `
+ UPDATE sora_generations
+ SET media_url = $2,
+ media_urls = $3,
+ file_size_bytes = $4,
+ storage_type = $5,
+ s3_object_keys = $6
+ WHERE id = $1 AND status = $7
+ `,
+ id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted,
+ )
+ if err != nil {
+ return false, err
+ }
+ affected, err := result.RowsAffected()
+ if err != nil {
+ return false, err
+ }
+ return affected > 0, nil
+}
+
+func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error {
+ _, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id)
+ return err
+}
+
+func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
+ // 构建 WHERE 条件
+ conditions := []string{"user_id = $1"}
+ args := []any{params.UserID}
+ argIdx := 2
+
+ if params.Status != "" {
+ // 支持逗号分隔的多状态
+ statuses := strings.Split(params.Status, ",")
+ placeholders := make([]string, len(statuses))
+ for i, s := range statuses {
+ placeholders[i] = fmt.Sprintf("$%d", argIdx)
+ args = append(args, strings.TrimSpace(s))
+ argIdx++
+ }
+ conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ",")))
+ }
+ if params.StorageType != "" {
+ storageTypes := strings.Split(params.StorageType, ",")
+ placeholders := make([]string, len(storageTypes))
+ for i, s := range storageTypes {
+ placeholders[i] = fmt.Sprintf("$%d", argIdx)
+ args = append(args, strings.TrimSpace(s))
+ argIdx++
+ }
+ conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ",")))
+ }
+ if params.MediaType != "" {
+ conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx))
+ args = append(args, params.MediaType)
+ argIdx++
+ }
+
+ whereClause := "WHERE " + strings.Join(conditions, " AND ")
+
+ // 计数
+ var total int64
+ countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause)
+ if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
+ return nil, 0, err
+ }
+
+ // 分页查询
+ offset := (params.Page - 1) * params.PageSize
+ listQuery := fmt.Sprintf(`
+ SELECT id, user_id, api_key_id, model, prompt, media_type,
+ status, media_url, media_urls, file_size_bytes,
+ storage_type, s3_object_keys, upstream_task_id, error_message,
+ created_at, completed_at
+ FROM sora_generations %s
+ ORDER BY created_at DESC
+ LIMIT $%d OFFSET $%d
+ `, whereClause, argIdx, argIdx+1)
+ args = append(args, params.PageSize, offset)
+
+ rows, err := r.sql.QueryContext(ctx, listQuery, args...)
+ if err != nil {
+ return nil, 0, err
+ }
+ defer func() {
+ _ = rows.Close()
+ }()
+
+ var results []*service.SoraGeneration
+ for rows.Next() {
+ gen := &service.SoraGeneration{}
+ var mediaURLsJSON, s3KeysJSON []byte
+ var completedAt sql.NullTime
+ var apiKeyID sql.NullInt64
+
+ if err := rows.Scan(
+ &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
+ &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
+ &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
+ &gen.CreatedAt, &completedAt,
+ ); err != nil {
+ return nil, 0, err
+ }
+
+ if apiKeyID.Valid {
+ gen.APIKeyID = &apiKeyID.Int64
+ }
+ if completedAt.Valid {
+ gen.CompletedAt = &completedAt.Time
+ }
+ _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
+ _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
+ results = append(results, gen)
+ }
+
+ return results, total, rows.Err()
+}
+
+func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) {
+ if len(statuses) == 0 {
+ return 0, nil
+ }
+
+ placeholders := make([]string, len(statuses))
+ args := []any{userID}
+ for i, s := range statuses {
+ placeholders[i] = fmt.Sprintf("$%d", i+2)
+ args = append(args, s)
+ }
+
+ var count int64
+ query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(placeholders, ","))
+ err := r.sql.QueryRowContext(ctx, query, args...).Scan(&count)
+ return count, err
+}
diff --git a/backend/internal/repository/usage_cleanup_repo.go b/backend/internal/repository/usage_cleanup_repo.go
index 9c021357..1a25696e 100644
--- a/backend/internal/repository/usage_cleanup_repo.go
+++ b/backend/internal/repository/usage_cleanup_repo.go
@@ -362,7 +362,12 @@ func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any)
idx++
}
}
- if filters.Stream != nil {
+ if filters.RequestType != nil {
+ condition, conditionArgs := buildRequestTypeFilterCondition(idx, *filters.RequestType)
+ conditions = append(conditions, condition)
+ args = append(args, conditionArgs...)
+ idx += len(conditionArgs)
+ } else if filters.Stream != nil {
conditions = append(conditions, fmt.Sprintf("stream = $%d", idx))
args = append(args, *filters.Stream)
idx++
diff --git a/backend/internal/repository/usage_cleanup_repo_test.go b/backend/internal/repository/usage_cleanup_repo_test.go
index 0ca30ec7..1ac7cca5 100644
--- a/backend/internal/repository/usage_cleanup_repo_test.go
+++ b/backend/internal/repository/usage_cleanup_repo_test.go
@@ -466,6 +466,38 @@ func TestBuildUsageCleanupWhere(t *testing.T) {
require.Equal(t, []any{start, end, userID, apiKeyID, accountID, groupID, "gpt-4", stream, billingType}, args)
}
+func TestBuildUsageCleanupWhereRequestTypePriority(t *testing.T) {
+ start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ end := start.Add(24 * time.Hour)
+ requestType := int16(service.RequestTypeWSV2)
+ stream := false
+
+ where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
+ StartTime: start,
+ EndTime: end,
+ RequestType: &requestType,
+ Stream: &stream,
+ })
+
+ require.Equal(t, "created_at >= $1 AND created_at <= $2 AND (request_type = $3 OR (request_type = 0 AND openai_ws_mode = TRUE))", where)
+ require.Equal(t, []any{start, end, requestType}, args)
+}
+
+func TestBuildUsageCleanupWhereRequestTypeLegacyFallback(t *testing.T) {
+ start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ end := start.Add(24 * time.Hour)
+ requestType := int16(service.RequestTypeStream)
+
+ where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
+ StartTime: start,
+ EndTime: end,
+ RequestType: &requestType,
+ })
+
+ require.Equal(t, "created_at >= $1 AND created_at <= $2 AND (request_type = $3 OR (request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE))", where)
+ require.Equal(t, []any{start, end, requestType}, args)
+}
+
func TestBuildUsageCleanupWhereModelEmpty(t *testing.T) {
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)
diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go
index ce67ba4d..d30cc7dd 100644
--- a/backend/internal/repository/usage_log_repo.go
+++ b/backend/internal/repository/usage_log_repo.go
@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
-const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
+const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
var dateFormatWhitelist = map[string]string{
@@ -98,6 +98,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log.RequestID = requestID
rateMultiplier := log.RateMultiplier
+ log.SyncRequestTypeAndLegacyFields()
+ requestType := int16(log.RequestType)
query := `
INSERT INTO usage_logs (
@@ -123,7 +125,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
rate_multiplier,
account_rate_multiplier,
billing_type,
+ request_type,
stream,
+ openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
@@ -140,7 +144,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
- $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33
+ $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -184,7 +188,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
rateMultiplier,
log.AccountRateMultiplier,
log.BillingType,
+ requestType,
log.Stream,
+ log.OpenAIWSMode,
duration,
firstToken,
userAgent,
@@ -492,25 +498,46 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
}
func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Context, stats *DashboardStats, startUTC, endUTC, todayUTC, now time.Time) error {
- totalStatsQuery := `
+ todayEnd := todayUTC.Add(24 * time.Hour)
+ combinedStatsQuery := `
+ WITH scoped AS (
+ SELECT
+ created_at,
+ input_tokens,
+ output_tokens,
+ cache_creation_tokens,
+ cache_read_tokens,
+ total_cost,
+ actual_cost,
+ COALESCE(duration_ms, 0) AS duration_ms
+ FROM usage_logs
+ WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz)
+ AND created_at < GREATEST($2::timestamptz, $4::timestamptz)
+ )
SELECT
- COUNT(*) as total_requests,
- COALESCE(SUM(input_tokens), 0) as total_input_tokens,
- COALESCE(SUM(output_tokens), 0) as total_output_tokens,
- COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
- COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
- COALESCE(SUM(total_cost), 0) as total_cost,
- COALESCE(SUM(actual_cost), 0) as total_actual_cost,
- COALESCE(SUM(COALESCE(duration_ms, 0)), 0) as total_duration_ms
- FROM usage_logs
- WHERE created_at >= $1 AND created_at < $2
+ COUNT(*) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz) AS total_requests,
+ COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_input_tokens,
+ COALESCE(SUM(output_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_output_tokens,
+ COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_creation_tokens,
+ COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_read_tokens,
+ COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cost,
+ COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_actual_cost,
+ COALESCE(SUM(duration_ms) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_duration_ms,
+ COUNT(*) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz) AS today_requests,
+ COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_input_tokens,
+ COALESCE(SUM(output_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_output_tokens,
+ COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_creation_tokens,
+ COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_read_tokens,
+ COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cost,
+ COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_actual_cost
+ FROM scoped
`
var totalDurationMs int64
if err := scanSingleRow(
ctx,
r.sql,
- totalStatsQuery,
- []any{startUTC, endUTC},
+ combinedStatsQuery,
+ []any{startUTC, endUTC, todayUTC, todayEnd},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
@@ -519,32 +546,6 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
&stats.TotalCost,
&stats.TotalActualCost,
&totalDurationMs,
- ); err != nil {
- return err
- }
- stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
- if stats.TotalRequests > 0 {
- stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
- }
-
- todayEnd := todayUTC.Add(24 * time.Hour)
- todayStatsQuery := `
- SELECT
- COUNT(*) as today_requests,
- COALESCE(SUM(input_tokens), 0) as today_input_tokens,
- COALESCE(SUM(output_tokens), 0) as today_output_tokens,
- COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens,
- COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens,
- COALESCE(SUM(total_cost), 0) as today_cost,
- COALESCE(SUM(actual_cost), 0) as today_actual_cost
- FROM usage_logs
- WHERE created_at >= $1 AND created_at < $2
- `
- if err := scanSingleRow(
- ctx,
- r.sql,
- todayStatsQuery,
- []any{todayUTC, todayEnd},
&stats.TodayRequests,
&stats.TodayInputTokens,
&stats.TodayOutputTokens,
@@ -555,25 +556,28 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
); err != nil {
return err
}
- stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
-
- activeUsersQuery := `
- SELECT COUNT(DISTINCT user_id) as active_users
- FROM usage_logs
- WHERE created_at >= $1 AND created_at < $2
- `
- if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd}, &stats.ActiveUsers); err != nil {
- return err
+ stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
+ if stats.TotalRequests > 0 {
+ stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
}
+ stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
+
hourStart := now.UTC().Truncate(time.Hour)
hourEnd := hourStart.Add(time.Hour)
- hourlyActiveQuery := `
- SELECT COUNT(DISTINCT user_id) as active_users
- FROM usage_logs
- WHERE created_at >= $1 AND created_at < $2
+ activeUsersQuery := `
+ WITH scoped AS (
+ SELECT user_id, created_at
+ FROM usage_logs
+ WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz)
+ AND created_at < GREATEST($2::timestamptz, $4::timestamptz)
+ )
+ SELECT
+ COUNT(DISTINCT CASE WHEN created_at >= $1::timestamptz AND created_at < $2::timestamptz THEN user_id END) AS active_users,
+ COUNT(DISTINCT CASE WHEN created_at >= $3::timestamptz AND created_at < $4::timestamptz THEN user_id END) AS hourly_active_users
+ FROM scoped
`
- if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart, hourEnd}, &stats.HourlyActiveUsers); err != nil {
+ if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd, hourStart, hourEnd}, &stats.ActiveUsers, &stats.HourlyActiveUsers); err != nil {
return err
}
@@ -968,6 +972,61 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc
return result, nil
}
+// GetGeminiUsageTotalsBatch 批量聚合 Gemini 账号在窗口内的 Pro/Flash 请求与用量。
+// 模型分类规则与 service.geminiModelClassFromName 一致:model 包含 flash/lite 视为 flash,其余视为 pro。
+func (r *usageLogRepository) GetGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, startTime, endTime time.Time) (map[int64]service.GeminiUsageTotals, error) {
+ result := make(map[int64]service.GeminiUsageTotals, len(accountIDs))
+ if len(accountIDs) == 0 {
+ return result, nil
+ }
+
+ query := `
+ SELECT
+ account_id,
+ COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 1 ELSE 0 END), 0) AS flash_requests,
+ COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE 1 END), 0) AS pro_requests,
+ COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) ELSE 0 END), 0) AS flash_tokens,
+ COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) END), 0) AS pro_tokens,
+ COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN actual_cost ELSE 0 END), 0) AS flash_cost,
+ COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE actual_cost END), 0) AS pro_cost
+ FROM usage_logs
+ WHERE account_id = ANY($1) AND created_at >= $2 AND created_at < $3
+ GROUP BY account_id
+ `
+ rows, err := r.sql.QueryContext(ctx, query, pq.Array(accountIDs), startTime, endTime)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ for rows.Next() {
+ var accountID int64
+ var totals service.GeminiUsageTotals
+ if err := rows.Scan(
+ &accountID,
+ &totals.FlashRequests,
+ &totals.ProRequests,
+ &totals.FlashTokens,
+ &totals.ProTokens,
+ &totals.FlashCost,
+ &totals.ProCost,
+ ); err != nil {
+ return nil, err
+ }
+ result[accountID] = totals
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
+ for _, accountID := range accountIDs {
+ if _, ok := result[accountID]; !ok {
+ result[accountID] = service.GeminiUsageTotals{}
+ }
+ }
+ return result, nil
+}
+
// TrendDataPoint represents a single point in trend data
type TrendDataPoint = usagestats.TrendDataPoint
@@ -1399,10 +1458,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
args = append(args, filters.Model)
}
- if filters.Stream != nil {
- conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
- args = append(args, *filters.Stream)
- }
+ conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream)
if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
@@ -1598,7 +1654,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
}
// GetUsageTrendWithFilters returns usage trend data with optional filters
-func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
+func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
dateFormat := safeDateFormat(granularity)
query := fmt.Sprintf(`
@@ -1636,10 +1692,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
args = append(args, model)
}
- if stream != nil {
- query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
- args = append(args, *stream)
- }
+ query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType))
@@ -1667,7 +1720,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
}
// GetModelStatsWithFilters returns model statistics with optional filters
-func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) (results []ModelStat, err error) {
+func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
if accountID > 0 && userID == 0 && apiKeyID == 0 {
@@ -1704,10 +1757,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID)
}
- if stream != nil {
- query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
- args = append(args, *stream)
- }
+ query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType))
@@ -1734,6 +1784,77 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
return results, nil
}
+// GetGroupStatsWithFilters returns group usage statistics with optional filters
+func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []usagestats.GroupStat, err error) {
+ query := `
+ SELECT
+ COALESCE(ul.group_id, 0) as group_id,
+ COALESCE(g.name, '') as group_name,
+ COUNT(*) as requests,
+ COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
+ COALESCE(SUM(ul.total_cost), 0) as cost,
+ COALESCE(SUM(ul.actual_cost), 0) as actual_cost
+ FROM usage_logs ul
+ LEFT JOIN groups g ON g.id = ul.group_id
+ WHERE ul.created_at >= $1 AND ul.created_at < $2
+ `
+
+ args := []any{startTime, endTime}
+ if userID > 0 {
+ query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1)
+ args = append(args, userID)
+ }
+ if apiKeyID > 0 {
+ query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1)
+ args = append(args, apiKeyID)
+ }
+ if accountID > 0 {
+ query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1)
+ args = append(args, accountID)
+ }
+ if groupID > 0 {
+ query += fmt.Sprintf(" AND ul.group_id = $%d", len(args)+1)
+ args = append(args, groupID)
+ }
+ query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
+ if billingType != nil {
+ query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1)
+ args = append(args, int16(*billingType))
+ }
+ query += " GROUP BY ul.group_id, g.name ORDER BY total_tokens DESC"
+
+ rows, err := r.sql.QueryContext(ctx, query, args...)
+ if err != nil {
+ return nil, err
+ }
+ defer func() {
+ if closeErr := rows.Close(); closeErr != nil && err == nil {
+ err = closeErr
+ results = nil
+ }
+ }()
+
+ results = make([]usagestats.GroupStat, 0)
+ for rows.Next() {
+ var row usagestats.GroupStat
+ if err := rows.Scan(
+ &row.GroupID,
+ &row.GroupName,
+ &row.Requests,
+ &row.TotalTokens,
+ &row.Cost,
+ &row.ActualCost,
+ ); err != nil {
+ return nil, err
+ }
+ results = append(results, row)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return results, nil
+}
+
// GetGlobalStats gets usage statistics for all users within a time range
func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
query := `
@@ -1794,10 +1915,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
args = append(args, filters.Model)
}
- if filters.Stream != nil {
- conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
- args = append(args, *filters.Stream)
- }
+ conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream)
if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
@@ -2017,7 +2135,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
}
}
- models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil)
+ models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil, nil)
if err != nil {
models = []ModelStat{}
}
@@ -2267,7 +2385,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
rateMultiplier float64
accountRateMultiplier sql.NullFloat64
billingType int16
+ requestTypeRaw int16
stream bool
+ openaiWSMode bool
durationMs sql.NullInt64
firstTokenMs sql.NullInt64
userAgent sql.NullString
@@ -2304,7 +2424,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&rateMultiplier,
&accountRateMultiplier,
&billingType,
+ &requestTypeRaw,
&stream,
+ &openaiWSMode,
&durationMs,
&firstTokenMs,
&userAgent,
@@ -2340,11 +2462,16 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
RateMultiplier: rateMultiplier,
AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier),
BillingType: int8(billingType),
- Stream: stream,
+ RequestType: service.RequestTypeFromInt16(requestTypeRaw),
ImageCount: imageCount,
CacheTTLOverridden: cacheTTLOverridden,
CreatedAt: createdAt,
}
+ // 先回填 legacy 字段,再基于 legacy + request_type 计算最终请求类型,保证历史数据兼容。
+ log.Stream = stream
+ log.OpenAIWSMode = openaiWSMode
+ log.RequestType = log.EffectiveRequestType()
+ log.Stream, log.OpenAIWSMode = service.ApplyLegacyRequestFields(log.RequestType, stream, openaiWSMode)
if requestID.Valid {
log.RequestID = requestID.String
@@ -2438,6 +2565,50 @@ func buildWhere(conditions []string) string {
return "WHERE " + strings.Join(conditions, " AND ")
}
+func appendRequestTypeOrStreamWhereCondition(conditions []string, args []any, requestType *int16, stream *bool) ([]string, []any) {
+ if requestType != nil {
+ condition, conditionArgs := buildRequestTypeFilterCondition(len(args)+1, *requestType)
+ conditions = append(conditions, condition)
+ args = append(args, conditionArgs...)
+ return conditions, args
+ }
+ if stream != nil {
+ conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
+ args = append(args, *stream)
+ }
+ return conditions, args
+}
+
+func appendRequestTypeOrStreamQueryFilter(query string, args []any, requestType *int16, stream *bool) (string, []any) {
+ if requestType != nil {
+ condition, conditionArgs := buildRequestTypeFilterCondition(len(args)+1, *requestType)
+ query += " AND " + condition
+ args = append(args, conditionArgs...)
+ return query, args
+ }
+ if stream != nil {
+ query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
+ args = append(args, *stream)
+ }
+ return query, args
+}
+
+// buildRequestTypeFilterCondition 在 request_type 过滤时兼容 legacy 字段,避免历史数据漏查。
+func buildRequestTypeFilterCondition(startArgIndex int, requestType int16) (string, []any) {
+ normalized := service.RequestTypeFromInt16(requestType)
+ requestTypeArg := int16(normalized)
+ switch normalized {
+ case service.RequestTypeSync:
+ return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND stream = FALSE AND openai_ws_mode = FALSE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg}
+ case service.RequestTypeStream:
+ return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND stream = TRUE AND openai_ws_mode = FALSE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg}
+ case service.RequestTypeWSV2:
+ return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND openai_ws_mode = TRUE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg}
+ default:
+ return fmt.Sprintf("request_type = $%d", startArgIndex), []any{requestTypeArg}
+ }
+}
+
func nullInt64(v *int64) sql.NullInt64 {
if v == nil {
return sql.NullInt64{}
diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go
index 8cb3aab1..4d50f7de 100644
--- a/backend/internal/repository/usage_log_repo_integration_test.go
+++ b/backend/internal/repository/usage_log_repo_integration_test.go
@@ -130,6 +130,62 @@ func (s *UsageLogRepoSuite) TestGetByID_ReturnsAccountRateMultiplier() {
s.Require().InEpsilon(0.5, *got.AccountRateMultiplier, 0.0001)
}
+func (s *UsageLogRepoSuite) TestGetByID_ReturnsOpenAIWSMode() {
+ user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-ws@test.com"})
+ apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-ws", Name: "k"})
+ account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-ws"})
+
+ log := &service.UsageLog{
+ UserID: user.ID,
+ APIKeyID: apiKey.ID,
+ AccountID: account.ID,
+ RequestID: uuid.New().String(),
+ Model: "gpt-5.3-codex",
+ InputTokens: 10,
+ OutputTokens: 20,
+ TotalCost: 1.0,
+ ActualCost: 1.0,
+ OpenAIWSMode: true,
+ CreatedAt: timezone.Today().Add(3 * time.Hour),
+ }
+ _, err := s.repo.Create(s.ctx, log)
+ s.Require().NoError(err)
+
+ got, err := s.repo.GetByID(s.ctx, log.ID)
+ s.Require().NoError(err)
+ s.Require().True(got.OpenAIWSMode)
+}
+
+func (s *UsageLogRepoSuite) TestGetByID_ReturnsRequestTypeAndLegacyFallback() {
+ user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-request-type@test.com"})
+ apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-request-type", Name: "k"})
+ account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-request-type"})
+
+ log := &service.UsageLog{
+ UserID: user.ID,
+ APIKeyID: apiKey.ID,
+ AccountID: account.ID,
+ RequestID: uuid.New().String(),
+ Model: "gpt-5.3-codex",
+ RequestType: service.RequestTypeWSV2,
+ Stream: true,
+ OpenAIWSMode: false,
+ InputTokens: 10,
+ OutputTokens: 20,
+ TotalCost: 1.0,
+ ActualCost: 1.0,
+ CreatedAt: timezone.Today().Add(4 * time.Hour),
+ }
+ _, err := s.repo.Create(s.ctx, log)
+ s.Require().NoError(err)
+
+ got, err := s.repo.GetByID(s.ctx, log.ID)
+ s.Require().NoError(err)
+ s.Require().Equal(service.RequestTypeWSV2, got.RequestType)
+ s.Require().True(got.Stream)
+ s.Require().True(got.OpenAIWSMode)
+}
+
// --- Delete ---
func (s *UsageLogRepoSuite) TestDelete() {
@@ -944,17 +1000,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
endTime := base.Add(48 * time.Hour)
// Test with user filter
- trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil)
+ trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
s.Require().Len(trend, 2)
// Test with apiKey filter
- trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil)
+ trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
s.Require().Len(trend, 2)
// Test with both filters
- trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil)
+ trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
s.Require().Len(trend, 2)
}
@@ -971,7 +1027,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
- trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil)
+ trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
s.Require().Len(trend, 2)
}
@@ -1017,17 +1073,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
endTime := base.Add(2 * time.Hour)
// Test with user filter
- stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil)
+ stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters user filter")
s.Require().Len(stats, 2)
// Test with apiKey filter
- stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil)
+ stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
s.Require().Len(stats, 2)
// Test with account filter
- stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil)
+ stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters account filter")
s.Require().Len(stats, 2)
}
diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go
new file mode 100644
index 00000000..95cf2a2d
--- /dev/null
+++ b/backend/internal/repository/usage_log_repo_request_type_test.go
@@ -0,0 +1,327 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "reflect"
+ "testing"
+ "time"
+
+ "github.com/DATA-DOG/go-sqlmock"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
+func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageLogRepository{sql: db}
+
+ createdAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
+ log := &service.UsageLog{
+ UserID: 1,
+ APIKeyID: 2,
+ AccountID: 3,
+ RequestID: "req-1",
+ Model: "gpt-5",
+ InputTokens: 10,
+ OutputTokens: 20,
+ TotalCost: 1,
+ ActualCost: 1,
+ BillingType: service.BillingTypeBalance,
+ RequestType: service.RequestTypeWSV2,
+ Stream: false,
+ OpenAIWSMode: false,
+ CreatedAt: createdAt,
+ }
+
+ mock.ExpectQuery("INSERT INTO usage_logs").
+ WithArgs(
+ log.UserID,
+ log.APIKeyID,
+ log.AccountID,
+ log.RequestID,
+ log.Model,
+ sqlmock.AnyArg(), // group_id
+ sqlmock.AnyArg(), // subscription_id
+ log.InputTokens,
+ log.OutputTokens,
+ log.CacheCreationTokens,
+ log.CacheReadTokens,
+ log.CacheCreation5mTokens,
+ log.CacheCreation1hTokens,
+ log.InputCost,
+ log.OutputCost,
+ log.CacheCreationCost,
+ log.CacheReadCost,
+ log.TotalCost,
+ log.ActualCost,
+ log.RateMultiplier,
+ log.AccountRateMultiplier,
+ log.BillingType,
+ int16(service.RequestTypeWSV2),
+ true,
+ true,
+ sqlmock.AnyArg(), // duration_ms
+ sqlmock.AnyArg(), // first_token_ms
+ sqlmock.AnyArg(), // user_agent
+ sqlmock.AnyArg(), // ip_address
+ log.ImageCount,
+ sqlmock.AnyArg(), // image_size
+ sqlmock.AnyArg(), // media_type
+ sqlmock.AnyArg(), // reasoning_effort
+ log.CacheTTLOverridden,
+ createdAt,
+ ).
+ WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
+
+ inserted, err := repo.Create(context.Background(), log)
+ require.NoError(t, err)
+ require.True(t, inserted)
+ require.Equal(t, int64(99), log.ID)
+ require.Equal(t, service.RequestTypeWSV2, log.RequestType)
+ require.True(t, log.Stream)
+ require.True(t, log.OpenAIWSMode)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageLogRepository{sql: db}
+
+ requestType := int16(service.RequestTypeWSV2)
+ stream := false
+ filters := usagestats.UsageLogFilters{
+ RequestType: &requestType,
+ Stream: &stream,
+ }
+
+ mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
+ WithArgs(requestType).
+ WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0)))
+ mock.ExpectQuery("SELECT .* FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\) ORDER BY id DESC LIMIT \\$2 OFFSET \\$3").
+ WithArgs(requestType, 20, 0).
+ WillReturnRows(sqlmock.NewRows([]string{"id"}))
+
+ logs, page, err := repo.ListWithFilters(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20}, filters)
+ require.NoError(t, err)
+ require.Empty(t, logs)
+ require.NotNil(t, page)
+ require.Equal(t, int64(0), page.Total)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageLogRepository{sql: db}
+
+ start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
+ end := start.Add(24 * time.Hour)
+ requestType := int16(service.RequestTypeStream)
+ stream := true
+
+ mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)").
+ WithArgs(start, end, requestType).
+ WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_tokens", "total_tokens", "cost", "actual_cost"}))
+
+ trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil)
+ require.NoError(t, err)
+ require.Empty(t, trend)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageLogRepository{sql: db}
+
+ start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
+ end := start.Add(24 * time.Hour)
+ requestType := int16(service.RequestTypeWSV2)
+ stream := false
+
+ mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
+ WithArgs(start, end, requestType).
+ WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "total_tokens", "cost", "actual_cost"}))
+
+ stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil)
+ require.NoError(t, err)
+ require.Empty(t, stats)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageLogRepository{sql: db}
+
+ requestType := int16(service.RequestTypeSync)
+ stream := true
+ filters := usagestats.UsageLogFilters{
+ RequestType: &requestType,
+ Stream: &stream,
+ }
+
+ mock.ExpectQuery("FROM usage_logs\\s+WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND stream = FALSE AND openai_ws_mode = FALSE\\)\\)").
+ WithArgs(requestType).
+ WillReturnRows(sqlmock.NewRows([]string{
+ "total_requests",
+ "total_input_tokens",
+ "total_output_tokens",
+ "total_cache_tokens",
+ "total_cost",
+ "total_actual_cost",
+ "total_account_cost",
+ "avg_duration_ms",
+ }).AddRow(int64(1), int64(2), int64(3), int64(4), 1.2, 1.0, 1.2, 20.0))
+
+ stats, err := repo.GetStatsWithFilters(context.Background(), filters)
+ require.NoError(t, err)
+ require.Equal(t, int64(1), stats.TotalRequests)
+ require.Equal(t, int64(9), stats.TotalTokens)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) {
+ tests := []struct {
+ name string
+ request int16
+ wantWhere string
+ wantArg int16
+ }{
+ {
+ name: "sync_with_legacy_fallback",
+ request: int16(service.RequestTypeSync),
+ wantWhere: "(request_type = $3 OR (request_type = 0 AND stream = FALSE AND openai_ws_mode = FALSE))",
+ wantArg: int16(service.RequestTypeSync),
+ },
+ {
+ name: "stream_with_legacy_fallback",
+ request: int16(service.RequestTypeStream),
+ wantWhere: "(request_type = $3 OR (request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE))",
+ wantArg: int16(service.RequestTypeStream),
+ },
+ {
+ name: "ws_v2_with_legacy_fallback",
+ request: int16(service.RequestTypeWSV2),
+ wantWhere: "(request_type = $3 OR (request_type = 0 AND openai_ws_mode = TRUE))",
+ wantArg: int16(service.RequestTypeWSV2),
+ },
+ {
+ name: "invalid_request_type_normalized_to_unknown",
+ request: int16(99),
+ wantWhere: "request_type = $3",
+ wantArg: int16(service.RequestTypeUnknown),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ where, args := buildRequestTypeFilterCondition(3, tt.request)
+ require.Equal(t, tt.wantWhere, where)
+ require.Equal(t, []any{tt.wantArg}, args)
+ })
+ }
+}
+
+type usageLogScannerStub struct {
+ values []any
+}
+
+func (s usageLogScannerStub) Scan(dest ...any) error {
+ if len(dest) != len(s.values) {
+ return fmt.Errorf("scan arg count mismatch: got %d want %d", len(dest), len(s.values))
+ }
+ for i := range dest {
+ dv := reflect.ValueOf(dest[i])
+ if dv.Kind() != reflect.Ptr {
+ return fmt.Errorf("dest[%d] is not pointer", i)
+ }
+ dv.Elem().Set(reflect.ValueOf(s.values[i]))
+ }
+ return nil
+}
+
+func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
+ t.Run("request_type_ws_v2_overrides_legacy", func(t *testing.T) {
+ now := time.Now().UTC()
+ log, err := scanUsageLog(usageLogScannerStub{values: []any{
+ int64(1), // id
+ int64(10), // user_id
+ int64(20), // api_key_id
+ int64(30), // account_id
+ sql.NullString{Valid: true, String: "req-1"},
+ "gpt-5", // model
+ sql.NullInt64{}, // group_id
+ sql.NullInt64{}, // subscription_id
+ 1, // input_tokens
+ 2, // output_tokens
+ 3, // cache_creation_tokens
+ 4, // cache_read_tokens
+ 5, // cache_creation_5m_tokens
+ 6, // cache_creation_1h_tokens
+ 0.1, // input_cost
+ 0.2, // output_cost
+ 0.3, // cache_creation_cost
+ 0.4, // cache_read_cost
+ 1.0, // total_cost
+ 0.9, // actual_cost
+ 1.0, // rate_multiplier
+ sql.NullFloat64{}, // account_rate_multiplier
+ int16(service.BillingTypeBalance),
+ int16(service.RequestTypeWSV2),
+ false, // legacy stream
+ false, // legacy openai ws
+ sql.NullInt64{},
+ sql.NullInt64{},
+ sql.NullString{},
+ sql.NullString{},
+ 0,
+ sql.NullString{},
+ sql.NullString{},
+ sql.NullString{},
+ false,
+ now,
+ }})
+ require.NoError(t, err)
+ require.Equal(t, service.RequestTypeWSV2, log.RequestType)
+ require.True(t, log.Stream)
+ require.True(t, log.OpenAIWSMode)
+ })
+
+ t.Run("request_type_unknown_falls_back_to_legacy", func(t *testing.T) {
+ now := time.Now().UTC()
+ log, err := scanUsageLog(usageLogScannerStub{values: []any{
+ int64(2),
+ int64(11),
+ int64(21),
+ int64(31),
+ sql.NullString{Valid: true, String: "req-2"},
+ "gpt-5",
+ sql.NullInt64{},
+ sql.NullInt64{},
+ 1, 2, 3, 4, 5, 6,
+ 0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
+ 1.0,
+ sql.NullFloat64{},
+ int16(service.BillingTypeBalance),
+ int16(service.RequestTypeUnknown),
+ true,
+ false,
+ sql.NullInt64{},
+ sql.NullInt64{},
+ sql.NullString{},
+ sql.NullString{},
+ 0,
+ sql.NullString{},
+ sql.NullString{},
+ sql.NullString{},
+ false,
+ now,
+ }})
+ require.NoError(t, err)
+ require.Equal(t, service.RequestTypeStream, log.RequestType)
+ require.True(t, log.Stream)
+ require.False(t, log.OpenAIWSMode)
+ })
+}
diff --git a/backend/internal/repository/user_group_rate_repo.go b/backend/internal/repository/user_group_rate_repo.go
index eb65403b..e3b11096 100644
--- a/backend/internal/repository/user_group_rate_repo.go
+++ b/backend/internal/repository/user_group_rate_repo.go
@@ -6,6 +6,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/lib/pq"
)
type userGroupRateRepository struct {
@@ -41,6 +42,59 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64)
return result, nil
}
+// GetByUserIDs 批量获取多个用户的专属分组倍率。
+// 返回结构:map[userID]map[groupID]rate
+func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) {
+ result := make(map[int64]map[int64]float64, len(userIDs))
+ if len(userIDs) == 0 {
+ return result, nil
+ }
+
+ uniqueIDs := make([]int64, 0, len(userIDs))
+ seen := make(map[int64]struct{}, len(userIDs))
+ for _, userID := range userIDs {
+ if userID <= 0 {
+ continue
+ }
+ if _, exists := seen[userID]; exists {
+ continue
+ }
+ seen[userID] = struct{}{}
+ uniqueIDs = append(uniqueIDs, userID)
+ result[userID] = make(map[int64]float64)
+ }
+ if len(uniqueIDs) == 0 {
+ return result, nil
+ }
+
+ rows, err := r.sql.QueryContext(ctx, `
+ SELECT user_id, group_id, rate_multiplier
+ FROM user_group_rate_multipliers
+ WHERE user_id = ANY($1)
+ `, pq.Array(uniqueIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ for rows.Next() {
+ var userID int64
+ var groupID int64
+ var rate float64
+ if err := rows.Scan(&userID, &groupID, &rate); err != nil {
+ return nil, err
+ }
+ if _, ok := result[userID]; !ok {
+ result[userID] = make(map[int64]float64)
+ }
+ result[userID][groupID] = rate
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return result, nil
+}
+
// GetByUserAndGroup 获取用户在特定分组的专属倍率
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
@@ -65,33 +119,43 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
// 分离需要删除和需要 upsert 的记录
var toDelete []int64
- toUpsert := make(map[int64]float64)
+ upsertGroupIDs := make([]int64, 0, len(rates))
+ upsertRates := make([]float64, 0, len(rates))
for groupID, rate := range rates {
if rate == nil {
toDelete = append(toDelete, groupID)
} else {
- toUpsert[groupID] = *rate
+ upsertGroupIDs = append(upsertGroupIDs, groupID)
+ upsertRates = append(upsertRates, *rate)
}
}
// 删除指定的记录
- for _, groupID := range toDelete {
- _, err := r.sql.ExecContext(ctx,
- `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`,
- userID, groupID)
- if err != nil {
+ if len(toDelete) > 0 {
+ if _, err := r.sql.ExecContext(ctx,
+ `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`,
+ userID, pq.Array(toDelete)); err != nil {
return err
}
}
// Upsert 记录
now := time.Now()
- for groupID, rate := range toUpsert {
+ if len(upsertGroupIDs) > 0 {
_, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
- VALUES ($1, $2, $3, $4, $4)
- ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4
- `, userID, groupID, rate, now)
+ SELECT
+ $1::bigint,
+ data.group_id,
+ data.rate_multiplier,
+ $2::timestamptz,
+ $2::timestamptz
+ FROM unnest($3::bigint[], $4::double precision[]) AS data(group_id, rate_multiplier)
+ ON CONFLICT (user_id, group_id)
+ DO UPDATE SET
+ rate_multiplier = EXCLUDED.rate_multiplier,
+ updated_at = EXCLUDED.updated_at
+ `, userID, now, pq.Array(upsertGroupIDs), pq.Array(upsertRates))
if err != nil {
return err
}
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index 17674291..05b68968 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -61,6 +61,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
+ SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
Save(ctx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
@@ -143,6 +144,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
+ SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
+ SetSoraStorageUsedBytes(userIn.SoraStorageUsedBytes).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
@@ -363,10 +366,79 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
return nil
}
+// AddSoraStorageUsageWithQuota 原子累加 Sora 存储用量,并在有配额时校验不超额。
+func (r *userRepository) AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) {
+ if deltaBytes <= 0 {
+ user, err := r.GetByID(ctx, userID)
+ if err != nil {
+ return 0, err
+ }
+ return user.SoraStorageUsedBytes, nil
+ }
+ var newUsed int64
+ err := scanSingleRow(ctx, r.sql, `
+ UPDATE users
+ SET sora_storage_used_bytes = sora_storage_used_bytes + $2
+ WHERE id = $1
+ AND ($3 = 0 OR sora_storage_used_bytes + $2 <= $3)
+ RETURNING sora_storage_used_bytes
+ `, []any{userID, deltaBytes, effectiveQuota}, &newUsed)
+ if err == nil {
+ return newUsed, nil
+ }
+ if errors.Is(err, sql.ErrNoRows) {
+ // 区分用户不存在和配额冲突
+ exists, existsErr := r.client.User.Query().Where(dbuser.IDEQ(userID)).Exist(ctx)
+ if existsErr != nil {
+ return 0, existsErr
+ }
+ if !exists {
+ return 0, service.ErrUserNotFound
+ }
+ return 0, service.ErrSoraStorageQuotaExceeded
+ }
+ return 0, err
+}
+
+// ReleaseSoraStorageUsageAtomic 原子释放 Sora 存储用量,并保证不低于 0。
+func (r *userRepository) ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) {
+ if deltaBytes <= 0 {
+ user, err := r.GetByID(ctx, userID)
+ if err != nil {
+ return 0, err
+ }
+ return user.SoraStorageUsedBytes, nil
+ }
+ var newUsed int64
+ err := scanSingleRow(ctx, r.sql, `
+ UPDATE users
+ SET sora_storage_used_bytes = GREATEST(sora_storage_used_bytes - $2, 0)
+ WHERE id = $1
+ RETURNING sora_storage_used_bytes
+ `, []any{userID, deltaBytes}, &newUsed)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return 0, service.ErrUserNotFound
+ }
+ return 0, err
+ }
+ return newUsed, nil
+}
+
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
}
+func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
+ client := clientFromContext(ctx, r.client)
+ return client.UserAllowedGroup.Create().
+ SetUserID(userID).
+ SetGroupID(groupID).
+ OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
+ DoNothing().
+ Exec(ctx)
+}
+
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
// 仅操作 user_allowed_groups 联接表,legacy users.allowed_groups 列已弃用。
affected, err := r.client.UserAllowedGroup.Delete().
diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go
index 0878c43d..2344035c 100644
--- a/backend/internal/repository/wire.go
+++ b/backend/internal/repository/wire.go
@@ -79,6 +79,7 @@ var ProviderSet = wire.NewSet(
NewTimeoutCounterCache,
ProvideConcurrencyCache,
ProvideSessionLimitCache,
+ NewRPMCache,
NewDashboardCache,
NewEmailCache,
NewIdentityCache,
@@ -106,6 +107,7 @@ var ProviderSet = wire.NewSet(
NewOpenAIOAuthClient,
NewGeminiOAuthClient,
NewGeminiCliCodeAssistClient,
+ NewGeminiDriveClient,
ProvideEnt,
ProvideSQLDB,
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index 5be390df..2738ed18 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -186,11 +186,12 @@ func TestAPIContracts(t *testing.T) {
"image_price_1k": null,
"image_price_2k": null,
"image_price_4k": null,
- "sora_image_price_360": null,
- "sora_image_price_540": null,
- "sora_video_price_per_request": null,
- "sora_video_price_per_request_hd": null,
- "claude_code_only": false,
+ "sora_image_price_360": null,
+ "sora_image_price_540": null,
+ "sora_storage_quota_bytes": 0,
+ "sora_video_price_per_request": null,
+ "sora_video_price_per_request_hd": null,
+ "claude_code_only": false,
"fallback_group_id": null,
"fallback_group_id_on_invalid_request": null,
"created_at": "2025-01-02T03:04:05Z",
@@ -384,10 +385,12 @@ func TestAPIContracts(t *testing.T) {
"user_id": 1,
"api_key_id": 100,
"account_id": 200,
- "request_id": "req_123",
- "model": "claude-3",
- "group_id": null,
- "subscription_id": null,
+ "request_id": "req_123",
+ "model": "claude-3",
+ "request_type": "stream",
+ "openai_ws_mode": false,
+ "group_id": null,
+ "subscription_id": null,
"input_tokens": 10,
"output_tokens": 20,
"cache_creation_tokens": 1,
@@ -496,18 +499,21 @@ func TestAPIContracts(t *testing.T) {
"doc_url": "https://docs.example.com",
"default_concurrency": 5,
"default_balance": 1.25,
+ "default_subscriptions": [],
"enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
"fallback_model_antigravity": "gemini-2.5-pro",
"fallback_model_gemini": "gemini-2.5-pro",
- "fallback_model_openai": "gpt-4o",
- "enable_identity_patch": true,
- "identity_patch_prompt": "",
- "invitation_code_enabled": false,
- "home_content": "",
+ "fallback_model_openai": "gpt-4o",
+ "enable_identity_patch": true,
+ "identity_patch_prompt": "",
+ "sora_client_enabled": false,
+ "invitation_code_enabled": false,
+ "home_content": "",
"hide_ccs_import_button": false,
"purchase_subscription_enabled": false,
- "purchase_subscription_url": ""
+ "purchase_subscription_url": "",
+ "min_claude_code_version": ""
}
}`,
},
@@ -615,12 +621,12 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
- adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil)
+ adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
- adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
- adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
+ adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil)
+ adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
@@ -775,6 +781,10 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
return 0, errors.New("not implemented")
}
+func (r *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
+ return errors.New("not implemented")
+}
+
func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
return errors.New("not implemented")
}
@@ -1555,11 +1565,15 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
return nil, errors.New("not implemented")
}
-func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
+func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
return nil, errors.New("not implemented")
}
-func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
+func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
return nil, errors.New("not implemented")
}
diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go
index 7b6d4ce8..033a5b77 100644
--- a/backend/internal/server/middleware/admin_auth_test.go
+++ b/backend/internal/server/middleware/admin_auth_test.go
@@ -19,7 +19,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
- authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil)
+ authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
admin := &service.User{
ID: 1,
@@ -181,6 +181,10 @@ func (s *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic("unexpected RemoveGroupFromAllowedGroups call")
}
+func (s *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
+ panic("unexpected AddGroupToAllowedGroups call")
+}
+
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go
index 8fa3517a..19f97239 100644
--- a/backend/internal/server/middleware/api_key_auth.go
+++ b/backend/internal/server/middleware/api_key_auth.go
@@ -97,7 +97,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
clientIP := ip.GetTrustedClientIP(c)
- allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
+ allowed, _ := ip.CheckIPRestrictionWithCompiledRules(clientIP, apiKey.CompiledIPWhitelist, apiKey.CompiledIPBlacklist)
if !allowed {
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
return
diff --git a/backend/internal/server/middleware/api_key_auth_google.go b/backend/internal/server/middleware/api_key_auth_google.go
index 9da1b1c6..84d93edc 100644
--- a/backend/internal/server/middleware/api_key_auth_google.go
+++ b/backend/internal/server/middleware/api_key_auth_google.go
@@ -80,17 +80,25 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
abortWithGoogleError(c, 403, "No active subscription found for this group")
return
}
- if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
- abortWithGoogleError(c, 403, err.Error())
- return
- }
- _ = subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription)
- _ = subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription)
- if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
- abortWithGoogleError(c, 429, err.Error())
+
+ needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
+ if err != nil {
+ status := 403
+ if errors.Is(err, service.ErrDailyLimitExceeded) ||
+ errors.Is(err, service.ErrWeeklyLimitExceeded) ||
+ errors.Is(err, service.ErrMonthlyLimitExceeded) {
+ status = 429
+ }
+ abortWithGoogleError(c, status, err.Error())
return
}
+
c.Set(string(ContextKeySubscription), subscription)
+
+ if needsMaintenance {
+ maintenanceCopy := *subscription
+ subscriptionService.DoWindowMaintenance(&maintenanceCopy)
+ }
} else {
if apiKey.User.Balance <= 0 {
abortWithGoogleError(c, 403, "Insufficient account balance")
diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go
index e4e0e253..2124c86c 100644
--- a/backend/internal/server/middleware/api_key_auth_google_test.go
+++ b/backend/internal/server/middleware/api_key_auth_google_test.go
@@ -23,6 +23,15 @@ type fakeAPIKeyRepo struct {
updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error
}
+type fakeGoogleSubscriptionRepo struct {
+ getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
+ updateStatus func(ctx context.Context, subscriptionID int64, status string) error
+ activateWindow func(ctx context.Context, id int64, start time.Time) error
+ resetDaily func(ctx context.Context, id int64, start time.Time) error
+ resetWeekly func(ctx context.Context, id int64, start time.Time) error
+ resetMonthly func(ctx context.Context, id int64, start time.Time) error
+}
+
func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented")
}
@@ -87,6 +96,85 @@ func (f fakeAPIKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt tim
return nil
}
+func (f fakeGoogleSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
+ return errors.New("not implemented")
+}
+func (f fakeGoogleSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
+ return nil, errors.New("not implemented")
+}
+func (f fakeGoogleSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
+ return nil, errors.New("not implemented")
+}
+func (f fakeGoogleSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
+ if f.getActive != nil {
+ return f.getActive(ctx, userID, groupID)
+ }
+ return nil, errors.New("not implemented")
+}
+func (f fakeGoogleSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error {
+ return errors.New("not implemented")
+}
+func (f fakeGoogleSubscriptionRepo) Delete(ctx context.Context, id int64) error {
+ return errors.New("not implemented")
+}
+func (f fakeGoogleSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
+ return nil, errors.New("not implemented")
+}
+func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
+ return nil, errors.New("not implemented")
+}
+func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
+ return nil, nil, errors.New("not implemented")
+}
+func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
+ return nil, nil, errors.New("not implemented")
+}
+func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
+ return false, errors.New("not implemented")
+}
+func (f fakeGoogleSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
+ return errors.New("not implemented")
+}
+func (f fakeGoogleSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
+ if f.updateStatus != nil {
+ return f.updateStatus(ctx, subscriptionID, status)
+ }
+ return errors.New("not implemented")
+}
+func (f fakeGoogleSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
+ return errors.New("not implemented")
+}
+func (f fakeGoogleSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
+ if f.activateWindow != nil {
+ return f.activateWindow(ctx, id, start)
+ }
+ return errors.New("not implemented")
+}
+func (f fakeGoogleSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, start time.Time) error {
+ if f.resetDaily != nil {
+ return f.resetDaily(ctx, id, start)
+ }
+ return errors.New("not implemented")
+}
+func (f fakeGoogleSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, start time.Time) error {
+ if f.resetWeekly != nil {
+ return f.resetWeekly(ctx, id, start)
+ }
+ return errors.New("not implemented")
+}
+func (f fakeGoogleSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, start time.Time) error {
+ if f.resetMonthly != nil {
+ return f.resetMonthly(ctx, id, start)
+ }
+ return errors.New("not implemented")
+}
+func (f fakeGoogleSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
+ return errors.New("not implemented")
+}
+func (f fakeGoogleSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
+ return 0, errors.New("not implemented")
+}
+
type googleErrorResponse struct {
Error struct {
Code int `json:"code"`
@@ -505,3 +593,85 @@ func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedInStandardMode(t *testi
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 1, touchCalls)
}
+
+func TestApiKeyAuthWithSubscriptionGoogle_SubscriptionLimitExceededReturns429(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ limit := 1.0
+ group := &service.Group{
+ ID: 77,
+ Name: "gemini-sub",
+ Status: service.StatusActive,
+ Platform: service.PlatformGemini,
+ Hydrated: true,
+ SubscriptionType: service.SubscriptionTypeSubscription,
+ DailyLimitUSD: &limit,
+ }
+ user := &service.User{
+ ID: 999,
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Balance: 10,
+ Concurrency: 3,
+ }
+ apiKey := &service.APIKey{
+ ID: 501,
+ UserID: user.ID,
+ Key: "google-sub-limit",
+ Status: service.StatusActive,
+ User: user,
+ Group: group,
+ }
+ apiKey.GroupID = &group.ID
+
+ apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
+ getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
+ if key != apiKey.Key {
+ return nil, service.ErrAPIKeyNotFound
+ }
+ clone := *apiKey
+ return &clone, nil
+ },
+ })
+
+ now := time.Now()
+ sub := &service.UserSubscription{
+ ID: 601,
+ UserID: user.ID,
+ GroupID: group.ID,
+ Status: service.SubscriptionStatusActive,
+ ExpiresAt: now.Add(24 * time.Hour),
+ DailyWindowStart: &now,
+ DailyUsageUSD: 10,
+ }
+ subscriptionService := service.NewSubscriptionService(nil, fakeGoogleSubscriptionRepo{
+ getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
+ if userID != user.ID || groupID != group.ID {
+ return nil, service.ErrSubscriptionNotFound
+ }
+ clone := *sub
+ return &clone, nil
+ },
+ updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil },
+ activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil },
+ resetDaily: func(ctx context.Context, id int64, start time.Time) error { return nil },
+ resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
+ resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
+ }, nil, nil, &config.Config{RunMode: config.RunModeStandard})
+
+ r := gin.New()
+ r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, &config.Config{RunMode: config.RunModeStandard}))
+ r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
+
+ req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
+ req.Header.Set("x-goog-api-key", apiKey.Key)
+ rec := httptest.NewRecorder()
+ r.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusTooManyRequests, rec.Code)
+ var resp googleErrorResponse
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ require.Equal(t, http.StatusTooManyRequests, resp.Error.Code)
+ require.Equal(t, "RESOURCE_EXHAUSTED", resp.Error.Status)
+ require.Contains(t, resp.Error.Message, "daily usage limit exceeded")
+}
diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go
index bc320958..f8839cfe 100644
--- a/backend/internal/server/middleware/jwt_auth_test.go
+++ b/backend/internal/server/middleware/jwt_auth_test.go
@@ -40,7 +40,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
cfg.JWT.AccessTokenExpireMinutes = 60
userRepo := &stubJWTUserRepo{users: users}
- authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil)
+ authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil)
mw := NewJWTAuthMiddleware(authSvc, userSvc)
diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go
index 67b19c09..f061db90 100644
--- a/backend/internal/server/middleware/security_headers.go
+++ b/backend/internal/server/middleware/security_headers.go
@@ -54,6 +54,10 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
c.Header("X-Content-Type-Options", "nosniff")
c.Header("X-Frame-Options", "DENY")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
+ if isAPIRoutePath(c) {
+ c.Next()
+ return
+ }
if cfg.Enabled {
// Generate nonce for this request
@@ -73,6 +77,18 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
}
}
+func isAPIRoutePath(c *gin.Context) bool {
+ if c == nil || c.Request == nil || c.Request.URL == nil {
+ return false
+ }
+ path := c.Request.URL.Path
+ return strings.HasPrefix(path, "/v1/") ||
+ strings.HasPrefix(path, "/v1beta/") ||
+ strings.HasPrefix(path, "/antigravity/") ||
+ strings.HasPrefix(path, "/sora/") ||
+ strings.HasPrefix(path, "/responses")
+}
+
// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
// This allows the application to work correctly even if the config file has an older CSP policy.
func enhanceCSPPolicy(policy string) string {
diff --git a/backend/internal/server/middleware/security_headers_test.go b/backend/internal/server/middleware/security_headers_test.go
index 43462b82..5a779825 100644
--- a/backend/internal/server/middleware/security_headers_test.go
+++ b/backend/internal/server/middleware/security_headers_test.go
@@ -131,6 +131,26 @@ func TestSecurityHeaders(t *testing.T) {
assert.Contains(t, csp, CloudflareInsightsDomain)
})
+ t.Run("api_route_skips_csp_nonce_generation", func(t *testing.T) {
+ cfg := config.CSPConfig{
+ Enabled: true,
+ Policy: "default-src 'self'; script-src 'self' __CSP_NONCE__",
+ }
+ middleware := SecurityHeaders(cfg)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
+
+ middleware(c)
+
+ assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options"))
+ assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options"))
+ assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy"))
+ assert.Empty(t, w.Header().Get("Content-Security-Policy"))
+ assert.Empty(t, GetNonceFromContext(c))
+ })
+
t.Run("csp_enabled_with_nonce_placeholder", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go
index fb91bc0e..07b51f23 100644
--- a/backend/internal/server/router.go
+++ b/backend/internal/server/router.go
@@ -75,6 +75,7 @@ func registerRoutes(
// 注册各模块路由
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
routes.RegisterUserRoutes(v1, h, jwtAuth)
+ routes.RegisterSoraClientRoutes(v1, h, jwtAuth)
routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg)
}
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index 36efacc8..c36c36a0 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -55,6 +55,9 @@ func RegisterAdminRoutes(
// 系统设置
registerSettingsRoutes(admin, h)
+ // 数据管理
+ registerDataManagementRoutes(admin, h)
+
// 运维监控(Ops)
registerOpsRoutes(admin, h)
@@ -72,6 +75,16 @@ func RegisterAdminRoutes(
// 错误透传规则管理
registerErrorPassthroughRoutes(admin, h)
+
+ // API Key 管理
+ registerAdminAPIKeyRoutes(admin, h)
+ }
+}
+
+func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
+ apiKeys := admin.Group("/api-keys")
+ {
+ apiKeys.PUT("/:id", h.Admin.APIKey.UpdateGroup)
}
}
@@ -171,6 +184,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
+ dashboard.GET("/groups", h.Admin.Dashboard.GetGroupStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
@@ -231,6 +245,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
+ accounts.POST("/today-stats/batch", h.Admin.Account.GetBatchTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable)
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
@@ -337,6 +352,7 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
codes.GET("/stats", h.Admin.Redeem.GetStats)
codes.GET("/export", h.Admin.Redeem.Export)
codes.GET("/:id", h.Admin.Redeem.GetByID)
+ codes.POST("/create-and-redeem", h.Admin.Redeem.CreateAndRedeem)
codes.POST("/generate", h.Admin.Redeem.Generate)
codes.DELETE("/:id", h.Admin.Redeem.Delete)
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
@@ -370,6 +386,38 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// 流超时处理配置
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
+ // Sora S3 存储配置
+ adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
+ adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
+ adminSettings.POST("/sora-s3/test", h.Admin.Setting.TestSoraS3Connection)
+ adminSettings.GET("/sora-s3/profiles", h.Admin.Setting.ListSoraS3Profiles)
+ adminSettings.POST("/sora-s3/profiles", h.Admin.Setting.CreateSoraS3Profile)
+ adminSettings.PUT("/sora-s3/profiles/:profile_id", h.Admin.Setting.UpdateSoraS3Profile)
+ adminSettings.DELETE("/sora-s3/profiles/:profile_id", h.Admin.Setting.DeleteSoraS3Profile)
+ adminSettings.POST("/sora-s3/profiles/:profile_id/activate", h.Admin.Setting.SetActiveSoraS3Profile)
+ }
+}
+
+func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
+ dataManagement := admin.Group("/data-management")
+ {
+ dataManagement.GET("/agent/health", h.Admin.DataManagement.GetAgentHealth)
+ dataManagement.GET("/config", h.Admin.DataManagement.GetConfig)
+ dataManagement.PUT("/config", h.Admin.DataManagement.UpdateConfig)
+ dataManagement.GET("/sources/:source_type/profiles", h.Admin.DataManagement.ListSourceProfiles)
+ dataManagement.POST("/sources/:source_type/profiles", h.Admin.DataManagement.CreateSourceProfile)
+ dataManagement.PUT("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.UpdateSourceProfile)
+ dataManagement.DELETE("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.DeleteSourceProfile)
+ dataManagement.POST("/sources/:source_type/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveSourceProfile)
+ dataManagement.POST("/s3/test", h.Admin.DataManagement.TestS3)
+ dataManagement.GET("/s3/profiles", h.Admin.DataManagement.ListS3Profiles)
+ dataManagement.POST("/s3/profiles", h.Admin.DataManagement.CreateS3Profile)
+ dataManagement.PUT("/s3/profiles/:profile_id", h.Admin.DataManagement.UpdateS3Profile)
+ dataManagement.DELETE("/s3/profiles/:profile_id", h.Admin.DataManagement.DeleteS3Profile)
+ dataManagement.POST("/s3/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveS3Profile)
+ dataManagement.POST("/backups", h.Admin.DataManagement.CreateBackupJob)
+ dataManagement.GET("/backups", h.Admin.DataManagement.ListBackupJobs)
+ dataManagement.GET("/backups/:job_id", h.Admin.DataManagement.GetBackupJob)
}
}
diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go
index 930c8b9e..6bd91b85 100644
--- a/backend/internal/server/routes/gateway.go
+++ b/backend/internal/server/routes/gateway.go
@@ -43,6 +43,7 @@ func RegisterGatewayRoutes(
gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API
gateway.POST("/responses", h.OpenAIGateway.Responses)
+ gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
// 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。
gateway.POST("/chat/completions", func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{
@@ -69,6 +70,7 @@ func RegisterGatewayRoutes(
// OpenAI Responses API(不带v1前缀的别名)
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
+ r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.ResponsesWebSocket)
// Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels)
diff --git a/backend/internal/server/routes/sora_client.go b/backend/internal/server/routes/sora_client.go
new file mode 100644
index 00000000..40ae0436
--- /dev/null
+++ b/backend/internal/server/routes/sora_client.go
@@ -0,0 +1,33 @@
+package routes
+
+import (
+ "github.com/Wei-Shaw/sub2api/internal/handler"
+ "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+
+ "github.com/gin-gonic/gin"
+)
+
+// RegisterSoraClientRoutes 注册 Sora 客户端 API 路由(需要用户认证)。
+func RegisterSoraClientRoutes(
+ v1 *gin.RouterGroup,
+ h *handler.Handlers,
+ jwtAuth middleware.JWTAuthMiddleware,
+) {
+ if h.SoraClient == nil {
+ return
+ }
+
+ authenticated := v1.Group("/sora")
+ authenticated.Use(gin.HandlerFunc(jwtAuth))
+ {
+ authenticated.POST("/generate", h.SoraClient.Generate)
+ authenticated.GET("/generations", h.SoraClient.ListGenerations)
+ authenticated.GET("/generations/:id", h.SoraClient.GetGeneration)
+ authenticated.DELETE("/generations/:id", h.SoraClient.DeleteGeneration)
+ authenticated.POST("/generations/:id/cancel", h.SoraClient.CancelGeneration)
+ authenticated.POST("/generations/:id/save", h.SoraClient.SaveToStorage)
+ authenticated.GET("/quota", h.SoraClient.GetQuota)
+ authenticated.GET("/models", h.SoraClient.GetModels)
+ authenticated.GET("/storage-status", h.SoraClient.GetStorageStatus)
+ }
+}
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index 50fdac88..c76c817e 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -3,6 +3,8 @@ package service
import (
"encoding/json"
+ "hash/fnv"
+ "reflect"
"sort"
"strconv"
"strings"
@@ -50,6 +52,14 @@ type Account struct {
AccountGroups []AccountGroup
GroupIDs []int64
Groups []*Group
+
+ // model_mapping 热路径缓存(非持久化字段)
+ modelMappingCache map[string]string
+ modelMappingCacheReady bool
+ modelMappingCacheCredentialsPtr uintptr
+ modelMappingCacheRawPtr uintptr
+ modelMappingCacheRawLen int
+ modelMappingCacheRawSig uint64
}
type TempUnschedulableRule struct {
@@ -349,6 +359,39 @@ func parseTempUnschedInt(value any) int {
}
func (a *Account) GetModelMapping() map[string]string {
+ credentialsPtr := mapPtr(a.Credentials)
+ rawMapping, _ := a.Credentials["model_mapping"].(map[string]any)
+ rawPtr := mapPtr(rawMapping)
+ rawLen := len(rawMapping)
+ rawSig := uint64(0)
+ rawSigReady := false
+
+ if a.modelMappingCacheReady &&
+ a.modelMappingCacheCredentialsPtr == credentialsPtr &&
+ a.modelMappingCacheRawPtr == rawPtr &&
+ a.modelMappingCacheRawLen == rawLen {
+ rawSig = modelMappingSignature(rawMapping)
+ rawSigReady = true
+ if a.modelMappingCacheRawSig == rawSig {
+ return a.modelMappingCache
+ }
+ }
+
+ mapping := a.resolveModelMapping(rawMapping)
+ if !rawSigReady {
+ rawSig = modelMappingSignature(rawMapping)
+ }
+
+ a.modelMappingCache = mapping
+ a.modelMappingCacheReady = true
+ a.modelMappingCacheCredentialsPtr = credentialsPtr
+ a.modelMappingCacheRawPtr = rawPtr
+ a.modelMappingCacheRawLen = rawLen
+ a.modelMappingCacheRawSig = rawSig
+ return mapping
+}
+
+func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]string {
if a.Credentials == nil {
// Antigravity 平台使用默认映射
if a.Platform == domain.PlatformAntigravity {
@@ -356,32 +399,31 @@ func (a *Account) GetModelMapping() map[string]string {
}
return nil
}
- raw, ok := a.Credentials["model_mapping"]
- if !ok || raw == nil {
+ if len(rawMapping) == 0 {
// Antigravity 平台使用默认映射
if a.Platform == domain.PlatformAntigravity {
return domain.DefaultAntigravityModelMapping
}
return nil
}
- if m, ok := raw.(map[string]any); ok {
- result := make(map[string]string)
- for k, v := range m {
- if s, ok := v.(string); ok {
- result[k] = s
- }
- }
- if len(result) > 0 {
- if a.Platform == domain.PlatformAntigravity {
- ensureAntigravityDefaultPassthroughs(result, []string{
- "gemini-3-flash",
- "gemini-3.1-pro-high",
- "gemini-3.1-pro-low",
- })
- }
- return result
+
+ result := make(map[string]string)
+ for k, v := range rawMapping {
+ if s, ok := v.(string); ok {
+ result[k] = s
}
}
+ if len(result) > 0 {
+ if a.Platform == domain.PlatformAntigravity {
+ ensureAntigravityDefaultPassthroughs(result, []string{
+ "gemini-3-flash",
+ "gemini-3.1-pro-high",
+ "gemini-3.1-pro-low",
+ })
+ }
+ return result
+ }
+
// Antigravity 平台使用默认映射
if a.Platform == domain.PlatformAntigravity {
return domain.DefaultAntigravityModelMapping
@@ -389,6 +431,37 @@ func (a *Account) GetModelMapping() map[string]string {
return nil
}
+func mapPtr(m map[string]any) uintptr {
+ if m == nil {
+ return 0
+ }
+ return reflect.ValueOf(m).Pointer()
+}
+
+func modelMappingSignature(rawMapping map[string]any) uint64 {
+ if len(rawMapping) == 0 {
+ return 0
+ }
+ keys := make([]string, 0, len(rawMapping))
+ for k := range rawMapping {
+ keys = append(keys, k)
+ }
+ sort.Strings(keys)
+
+ h := fnv.New64a()
+ for _, k := range keys {
+ _, _ = h.Write([]byte(k))
+ _, _ = h.Write([]byte{0})
+ if v, ok := rawMapping[k].(string); ok {
+ _, _ = h.Write([]byte(v))
+ } else {
+ _, _ = h.Write([]byte{1})
+ }
+ _, _ = h.Write([]byte{0xff})
+ }
+ return h.Sum64()
+}
+
func ensureAntigravityDefaultPassthrough(mapping map[string]string, model string) {
if mapping == nil || model == "" {
return
@@ -742,6 +815,159 @@ func (a *Account) IsOpenAIPassthroughEnabled() bool {
return false
}
+// IsOpenAIResponsesWebSocketV2Enabled 返回 OpenAI 账号是否开启 Responses WebSocket v2。
+//
+// 分类型新字段:
+// - OAuth 账号:accounts.extra.openai_oauth_responses_websockets_v2_enabled
+// - API Key 账号:accounts.extra.openai_apikey_responses_websockets_v2_enabled
+//
+// 兼容字段:
+// - accounts.extra.responses_websockets_v2_enabled
+// - accounts.extra.openai_ws_enabled(历史开关)
+//
+// 优先级:
+// 1. 按账号类型读取分类型字段
+// 2. 分类型字段缺失时,回退兼容字段
+func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
+ if a == nil || !a.IsOpenAI() || a.Extra == nil {
+ return false
+ }
+ if a.IsOpenAIOAuth() {
+ if enabled, ok := a.Extra["openai_oauth_responses_websockets_v2_enabled"].(bool); ok {
+ return enabled
+ }
+ }
+ if a.IsOpenAIApiKey() {
+ if enabled, ok := a.Extra["openai_apikey_responses_websockets_v2_enabled"].(bool); ok {
+ return enabled
+ }
+ }
+ if enabled, ok := a.Extra["responses_websockets_v2_enabled"].(bool); ok {
+ return enabled
+ }
+ if enabled, ok := a.Extra["openai_ws_enabled"].(bool); ok {
+ return enabled
+ }
+ return false
+}
+
+const (
+ OpenAIWSIngressModeOff = "off"
+ OpenAIWSIngressModeShared = "shared"
+ OpenAIWSIngressModeDedicated = "dedicated"
+)
+
+func normalizeOpenAIWSIngressMode(mode string) string {
+ switch strings.ToLower(strings.TrimSpace(mode)) {
+ case OpenAIWSIngressModeOff:
+ return OpenAIWSIngressModeOff
+ case OpenAIWSIngressModeShared:
+ return OpenAIWSIngressModeShared
+ case OpenAIWSIngressModeDedicated:
+ return OpenAIWSIngressModeDedicated
+ default:
+ return ""
+ }
+}
+
+func normalizeOpenAIWSIngressDefaultMode(mode string) string {
+ if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" {
+ return normalized
+ }
+ return OpenAIWSIngressModeShared
+}
+
+// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/shared/dedicated)。
+//
+// 优先级:
+// 1. 分类型 mode 新字段(string)
+// 2. 分类型 enabled 旧字段(bool)
+// 3. 兼容 enabled 旧字段(bool)
+// 4. defaultMode(非法时回退 shared)
+func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
+ resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
+ if a == nil || !a.IsOpenAI() {
+ return OpenAIWSIngressModeOff
+ }
+ if a.Extra == nil {
+ return resolvedDefault
+ }
+
+ resolveModeString := func(key string) (string, bool) {
+ raw, ok := a.Extra[key]
+ if !ok {
+ return "", false
+ }
+ mode, ok := raw.(string)
+ if !ok {
+ return "", false
+ }
+ normalized := normalizeOpenAIWSIngressMode(mode)
+ if normalized == "" {
+ return "", false
+ }
+ return normalized, true
+ }
+ resolveBoolMode := func(key string) (string, bool) {
+ raw, ok := a.Extra[key]
+ if !ok {
+ return "", false
+ }
+ enabled, ok := raw.(bool)
+ if !ok {
+ return "", false
+ }
+ if enabled {
+ return OpenAIWSIngressModeShared, true
+ }
+ return OpenAIWSIngressModeOff, true
+ }
+
+ if a.IsOpenAIOAuth() {
+ if mode, ok := resolveModeString("openai_oauth_responses_websockets_v2_mode"); ok {
+ return mode
+ }
+ if mode, ok := resolveBoolMode("openai_oauth_responses_websockets_v2_enabled"); ok {
+ return mode
+ }
+ }
+ if a.IsOpenAIApiKey() {
+ if mode, ok := resolveModeString("openai_apikey_responses_websockets_v2_mode"); ok {
+ return mode
+ }
+ if mode, ok := resolveBoolMode("openai_apikey_responses_websockets_v2_enabled"); ok {
+ return mode
+ }
+ }
+ if mode, ok := resolveBoolMode("responses_websockets_v2_enabled"); ok {
+ return mode
+ }
+ if mode, ok := resolveBoolMode("openai_ws_enabled"); ok {
+ return mode
+ }
+ return resolvedDefault
+}
+
+// IsOpenAIWSForceHTTPEnabled 返回账号级“强制 HTTP”开关。
+// 字段:accounts.extra.openai_ws_force_http。
+func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
+ if a == nil || !a.IsOpenAI() || a.Extra == nil {
+ return false
+ }
+ enabled, ok := a.Extra["openai_ws_force_http"].(bool)
+ return ok && enabled
+}
+
+// IsOpenAIWSAllowStoreRecoveryEnabled 返回账号级 store 恢复开关。
+// 字段:accounts.extra.openai_ws_allow_store_recovery。
+func (a *Account) IsOpenAIWSAllowStoreRecoveryEnabled() bool {
+ if a == nil || !a.IsOpenAI() || a.Extra == nil {
+ return false
+ }
+ enabled, ok := a.Extra["openai_ws_allow_store_recovery"].(bool)
+ return ok && enabled
+}
+
// IsOpenAIOAuthPassthroughEnabled 兼容旧接口,等价于 OAuth 账号的 IsOpenAIPassthroughEnabled。
func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
@@ -911,6 +1137,80 @@ func (a *Account) GetSessionIdleTimeoutMinutes() int {
return 5
}
+// GetBaseRPM 获取基础 RPM 限制
+// 返回 0 表示未启用(负数视为无效配置,按 0 处理)
+func (a *Account) GetBaseRPM() int {
+ if a.Extra == nil {
+ return 0
+ }
+ if v, ok := a.Extra["base_rpm"]; ok {
+ val := parseExtraInt(v)
+ if val > 0 {
+ return val
+ }
+ }
+ return 0
+}
+
+// GetRPMStrategy 获取 RPM 策略
+// "tiered" = 三区模型(默认), "sticky_exempt" = 粘性豁免
+func (a *Account) GetRPMStrategy() string {
+ if a.Extra == nil {
+ return "tiered"
+ }
+ if v, ok := a.Extra["rpm_strategy"]; ok {
+ if s, ok := v.(string); ok && s == "sticky_exempt" {
+ return "sticky_exempt"
+ }
+ }
+ return "tiered"
+}
+
+// GetRPMStickyBuffer 获取 RPM 粘性缓冲数量
+// tiered 模式下的黄区大小,默认为 base_rpm 的 20%(至少 1)
+func (a *Account) GetRPMStickyBuffer() int {
+ if a.Extra == nil {
+ return 0
+ }
+ if v, ok := a.Extra["rpm_sticky_buffer"]; ok {
+ val := parseExtraInt(v)
+ if val > 0 {
+ return val
+ }
+ }
+ base := a.GetBaseRPM()
+ buffer := base / 5
+ if buffer < 1 && base > 0 {
+ buffer = 1
+ }
+ return buffer
+}
+
+// CheckRPMSchedulability 根据当前 RPM 计数检查调度状态
+// 复用 WindowCostSchedulability 三态:Schedulable / StickyOnly / NotSchedulable
+func (a *Account) CheckRPMSchedulability(currentRPM int) WindowCostSchedulability {
+ baseRPM := a.GetBaseRPM()
+ if baseRPM <= 0 {
+ return WindowCostSchedulable
+ }
+
+ if currentRPM < baseRPM {
+ return WindowCostSchedulable
+ }
+
+ strategy := a.GetRPMStrategy()
+ if strategy == "sticky_exempt" {
+ return WindowCostStickyOnly // 粘性豁免无红区
+ }
+
+ // tiered: 黄区 + 红区
+ buffer := a.GetRPMStickyBuffer()
+ if currentRPM < baseRPM+buffer {
+ return WindowCostStickyOnly
+ }
+ return WindowCostNotSchedulable
+}
+
// CheckWindowCostSchedulability 根据当前窗口费用检查调度状态
// - 费用 < 阈值: WindowCostSchedulable(可正常调度)
// - 费用 >= 阈值 且 < 阈值+预留: WindowCostStickyOnly(仅粘性会话)
@@ -974,6 +1274,12 @@ func parseExtraFloat64(value any) float64 {
}
// parseExtraInt 从 extra 字段解析 int 值
+// ParseExtraInt 从 extra 字段的 any 值解析为 int。
+// 支持 int, int64, float64, json.Number, string 类型,无法解析时返回 0。
+func ParseExtraInt(value any) int {
+ return parseExtraInt(value)
+}
+
func parseExtraInt(value any) int {
switch v := value.(type) {
case int:
diff --git a/backend/internal/service/account_openai_passthrough_test.go b/backend/internal/service/account_openai_passthrough_test.go
index 59f8cd8c..a85c68ec 100644
--- a/backend/internal/service/account_openai_passthrough_test.go
+++ b/backend/internal/service/account_openai_passthrough_test.go
@@ -134,3 +134,161 @@ func TestAccount_IsCodexCLIOnlyEnabled(t *testing.T) {
require.False(t, otherPlatform.IsCodexCLIOnlyEnabled())
})
}
+
+func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
+ t.Run("OAuth使用OAuth专用开关", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "openai_oauth_responses_websockets_v2_enabled": true,
+ },
+ }
+ require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
+ })
+
+ t.Run("API Key使用API Key专用开关", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_enabled": true,
+ },
+ }
+ require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
+ })
+
+ t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_enabled": true,
+ },
+ }
+ require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
+ })
+
+ t.Run("分类型新键优先于兼容键", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "openai_oauth_responses_websockets_v2_enabled": false,
+ "responses_websockets_v2_enabled": true,
+ "openai_ws_enabled": true,
+ },
+ }
+ require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
+ })
+
+ t.Run("分类型键缺失时回退兼容键", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+ require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
+ })
+
+ t.Run("非OpenAI账号默认关闭", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+ require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
+ })
+}
+
+func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
+ t.Run("default fallback to shared", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{},
+ }
+ require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
+ require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
+ })
+
+ t.Run("oauth mode field has highest priority", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
+ "openai_oauth_responses_websockets_v2_enabled": false,
+ "responses_websockets_v2_enabled": false,
+ },
+ }
+ require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
+ })
+
+ t.Run("legacy enabled maps to shared", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+ require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
+ })
+
+ t.Run("legacy disabled maps to off", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_enabled": false,
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+ require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
+ })
+
+ t.Run("non openai always off", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
+ },
+ }
+ require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeDedicated))
+ })
+}
+
+func TestAccount_OpenAIWSExtraFlags(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "openai_ws_force_http": true,
+ "openai_ws_allow_store_recovery": true,
+ },
+ }
+ require.True(t, account.IsOpenAIWSForceHTTPEnabled())
+ require.True(t, account.IsOpenAIWSAllowStoreRecoveryEnabled())
+
+ off := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}}
+ require.False(t, off.IsOpenAIWSForceHTTPEnabled())
+ require.False(t, off.IsOpenAIWSAllowStoreRecoveryEnabled())
+
+ var nilAccount *Account
+ require.False(t, nilAccount.IsOpenAIWSAllowStoreRecoveryEnabled())
+
+ nonOpenAI := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "openai_ws_allow_store_recovery": true,
+ },
+ }
+ require.False(t, nonOpenAI.IsOpenAIWSAllowStoreRecoveryEnabled())
+}
diff --git a/backend/internal/service/account_rpm_test.go b/backend/internal/service/account_rpm_test.go
new file mode 100644
index 00000000..9d91f3e0
--- /dev/null
+++ b/backend/internal/service/account_rpm_test.go
@@ -0,0 +1,120 @@
+package service
+
+import (
+ "encoding/json"
+ "testing"
+)
+
+func TestGetBaseRPM(t *testing.T) {
+ tests := []struct {
+ name string
+ extra map[string]any
+ expected int
+ }{
+ {"nil extra", nil, 0},
+ {"no key", map[string]any{}, 0},
+ {"zero", map[string]any{"base_rpm": 0}, 0},
+ {"int value", map[string]any{"base_rpm": 15}, 15},
+ {"float value", map[string]any{"base_rpm": 15.0}, 15},
+ {"string value", map[string]any{"base_rpm": "15"}, 15},
+ {"negative value", map[string]any{"base_rpm": -5}, 0},
+ {"int64 value", map[string]any{"base_rpm": int64(20)}, 20},
+ {"json.Number value", map[string]any{"base_rpm": json.Number("25")}, 25},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ a := &Account{Extra: tt.extra}
+ if got := a.GetBaseRPM(); got != tt.expected {
+ t.Errorf("GetBaseRPM() = %d, want %d", got, tt.expected)
+ }
+ })
+ }
+}
+
+func TestGetRPMStrategy(t *testing.T) {
+ tests := []struct {
+ name string
+ extra map[string]any
+ expected string
+ }{
+ {"nil extra", nil, "tiered"},
+ {"no key", map[string]any{}, "tiered"},
+ {"tiered", map[string]any{"rpm_strategy": "tiered"}, "tiered"},
+ {"sticky_exempt", map[string]any{"rpm_strategy": "sticky_exempt"}, "sticky_exempt"},
+ {"invalid", map[string]any{"rpm_strategy": "foobar"}, "tiered"},
+ {"empty string fallback", map[string]any{"rpm_strategy": ""}, "tiered"},
+ {"numeric value fallback", map[string]any{"rpm_strategy": 123}, "tiered"},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ a := &Account{Extra: tt.extra}
+ if got := a.GetRPMStrategy(); got != tt.expected {
+ t.Errorf("GetRPMStrategy() = %q, want %q", got, tt.expected)
+ }
+ })
+ }
+}
+
+func TestCheckRPMSchedulability(t *testing.T) {
+ tests := []struct {
+ name string
+ extra map[string]any
+ currentRPM int
+ expected WindowCostSchedulability
+ }{
+ {"disabled", map[string]any{}, 100, WindowCostSchedulable},
+ {"green zone", map[string]any{"base_rpm": 15}, 10, WindowCostSchedulable},
+ {"yellow zone tiered", map[string]any{"base_rpm": 15}, 15, WindowCostStickyOnly},
+ {"red zone tiered", map[string]any{"base_rpm": 15}, 18, WindowCostNotSchedulable},
+ {"sticky_exempt at limit", map[string]any{"base_rpm": 15, "rpm_strategy": "sticky_exempt"}, 15, WindowCostStickyOnly},
+ {"sticky_exempt over limit", map[string]any{"base_rpm": 15, "rpm_strategy": "sticky_exempt"}, 100, WindowCostStickyOnly},
+ {"custom buffer", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 14, WindowCostStickyOnly},
+ {"custom buffer red", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 15, WindowCostNotSchedulable},
+ {"base_rpm=1 green", map[string]any{"base_rpm": 1}, 0, WindowCostSchedulable},
+ {"base_rpm=1 yellow (at limit)", map[string]any{"base_rpm": 1}, 1, WindowCostStickyOnly},
+ {"base_rpm=1 red (at limit+buffer)", map[string]any{"base_rpm": 1}, 2, WindowCostNotSchedulable},
+ {"negative currentRPM", map[string]any{"base_rpm": 15}, -1, WindowCostSchedulable},
+ {"base_rpm negative disabled", map[string]any{"base_rpm": -5}, 10, WindowCostSchedulable},
+ {"very high currentRPM", map[string]any{"base_rpm": 10}, 9999, WindowCostNotSchedulable},
+ {"sticky_exempt very high currentRPM", map[string]any{"base_rpm": 10, "rpm_strategy": "sticky_exempt"}, 9999, WindowCostStickyOnly},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ a := &Account{Extra: tt.extra}
+ if got := a.CheckRPMSchedulability(tt.currentRPM); got != tt.expected {
+ t.Errorf("CheckRPMSchedulability(%d) = %d, want %d", tt.currentRPM, got, tt.expected)
+ }
+ })
+ }
+}
+
+func TestGetRPMStickyBuffer(t *testing.T) {
+ tests := []struct {
+ name string
+ extra map[string]any
+ expected int
+ }{
+ {"nil extra", nil, 0},
+ {"no keys", map[string]any{}, 0},
+ {"base_rpm=0", map[string]any{"base_rpm": 0}, 0},
+ {"base_rpm=1 min buffer 1", map[string]any{"base_rpm": 1}, 1},
+ {"base_rpm=4 min buffer 1", map[string]any{"base_rpm": 4}, 1},
+ {"base_rpm=5 buffer 1", map[string]any{"base_rpm": 5}, 1},
+ {"base_rpm=10 buffer 2", map[string]any{"base_rpm": 10}, 2},
+ {"base_rpm=15 buffer 3", map[string]any{"base_rpm": 15}, 3},
+ {"base_rpm=100 buffer 20", map[string]any{"base_rpm": 100}, 20},
+ {"custom buffer=5", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 5},
+ {"custom buffer=0 fallback to default", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 0}, 2},
+ {"custom buffer negative fallback", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": -1}, 2},
+ {"custom buffer with float", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": float64(7)}, 7},
+ {"json.Number base_rpm", map[string]any{"base_rpm": json.Number("10")}, 2},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ a := &Account{Extra: tt.extra}
+ if got := a.GetRPMStickyBuffer(); got != tt.expected {
+ t.Errorf("GetRPMStickyBuffer() = %d, want %d", got, tt.expected)
+ }
+ })
+ }
+}
diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go
index b301049f..a3707184 100644
--- a/backend/internal/service/account_service.go
+++ b/backend/internal/service/account_service.go
@@ -119,6 +119,10 @@ type AccountService struct {
groupRepo GroupRepository
}
+type groupExistenceBatchChecker interface {
+ ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
+}
+
// NewAccountService 创建账号服务实例
func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) *AccountService {
return &AccountService{
@@ -131,11 +135,8 @@ func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository)
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) {
// 验证分组是否存在(如果指定了分组)
if len(req.GroupIDs) > 0 {
- for _, groupID := range req.GroupIDs {
- _, err := s.groupRepo.GetByID(ctx, groupID)
- if err != nil {
- return nil, fmt.Errorf("get group: %w", err)
- }
+ if err := s.validateGroupIDsExist(ctx, req.GroupIDs); err != nil {
+ return nil, err
}
}
@@ -256,11 +257,8 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
// 先验证分组是否存在(在任何写操作之前)
if req.GroupIDs != nil {
- for _, groupID := range *req.GroupIDs {
- _, err := s.groupRepo.GetByID(ctx, groupID)
- if err != nil {
- return nil, fmt.Errorf("get group: %w", err)
- }
+ if err := s.validateGroupIDsExist(ctx, *req.GroupIDs); err != nil {
+ return nil, err
}
}
@@ -300,6 +298,39 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error {
return nil
}
+func (s *AccountService) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error {
+ if len(groupIDs) == 0 {
+ return nil
+ }
+ if s.groupRepo == nil {
+ return fmt.Errorf("group repository not configured")
+ }
+
+ if batchChecker, ok := s.groupRepo.(groupExistenceBatchChecker); ok {
+ existsByID, err := batchChecker.ExistsByIDs(ctx, groupIDs)
+ if err != nil {
+ return fmt.Errorf("check groups exists: %w", err)
+ }
+ for _, groupID := range groupIDs {
+ if groupID <= 0 {
+ return fmt.Errorf("get group: %w", ErrGroupNotFound)
+ }
+ if !existsByID[groupID] {
+ return fmt.Errorf("get group: %w", ErrGroupNotFound)
+ }
+ }
+ return nil
+ }
+
+ for _, groupID := range groupIDs {
+ _, err := s.groupRepo.GetByID(ctx, groupID)
+ if err != nil {
+ return fmt.Errorf("get group: %w", err)
+ }
+ }
+ return nil
+}
+
// UpdateStatus 更新账号状态
func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error {
account, err := s.accountRepo.GetByID(ctx, id)
diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go
index a507efb4..c55e418d 100644
--- a/backend/internal/service/account_test_service.go
+++ b/backend/internal/service/account_test_service.go
@@ -598,9 +598,102 @@ func ceilSeconds(d time.Duration) int {
return sec
}
+// testSoraAPIKeyAccountConnection 测试 Sora apikey 类型账号的连通性。
+// 向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性和 API Key 有效性。
+func (s *AccountTestService) testSoraAPIKeyAccountConnection(c *gin.Context, account *Account) error {
+ ctx := c.Request.Context()
+
+ apiKey := account.GetCredential("api_key")
+ if apiKey == "" {
+ return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 api_key 凭证")
+ }
+
+ baseURL := account.GetBaseURL()
+ if baseURL == "" {
+ return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 base_url")
+ }
+
+ // 验证 base_url 格式
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("base_url 无效: %s", err.Error()))
+ }
+ upstreamURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/sora/v1/chat/completions"
+
+ // 设置 SSE 头
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
+ c.Writer.Flush()
+
+ if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
+ msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
+ return s.sendErrorAndEnd(c, msg)
+ }
+
+ s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora-upstream"})
+
+ // 构建轻量级 prompt-enhance 请求作为连通性测试
+ testPayload := map[string]any{
+ "model": "prompt-enhance-short-10s",
+ "messages": []map[string]string{{"role": "user", "content": "test"}},
+ "stream": false,
+ }
+ payloadBytes, _ := json.Marshal(testPayload)
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(payloadBytes))
+ if err != nil {
+ return s.sendErrorAndEnd(c, "构建测试请求失败")
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", "Bearer "+apiKey)
+
+ // 获取代理 URL
+ proxyURL := ""
+ if account.ProxyID != nil && account.Proxy != nil {
+ proxyURL = account.Proxy.URL()
+ }
+
+ resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("上游连接失败: %s", err.Error()))
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
+
+ if resp.StatusCode == http.StatusOK {
+ s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
+ s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效 (HTTP %d)", resp.StatusCode)})
+ s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
+ return nil
+ }
+
+ if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("上游认证失败 (HTTP %d),请检查 API Key 是否正确", resp.StatusCode))
+ }
+
+ // 其他错误但能连通(如 400 参数错误)也算连通性测试通过
+ if resp.StatusCode == http.StatusBadRequest {
+ s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
+ s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效(上游返回 %d,参数校验错误属正常)", resp.StatusCode)})
+ s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
+ return nil
+ }
+
+ return s.sendErrorAndEnd(c, fmt.Sprintf("上游返回异常 HTTP %d: %s", resp.StatusCode, truncateSoraErrorBody(respBody, 256)))
+}
+
// testSoraAccountConnection 测试 Sora 账号的连接
-// 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token)
+// OAuth 类型:调用 /backend/me 接口验证 access_token 有效性
+// APIKey 类型:向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
+ // apikey 类型走独立测试流程
+ if account.Type == AccountTypeAPIKey {
+ return s.testSoraAPIKeyAccountConnection(c, account)
+ }
+
ctx := c.Request.Context()
recorder := &soraProbeRecorder{}
diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go
index a363a790..6dee6c13 100644
--- a/backend/internal/service/account_usage_service.go
+++ b/backend/internal/service/account_usage_service.go
@@ -9,7 +9,9 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
+ "golang.org/x/sync/errgroup"
)
type UsageLogRepository interface {
@@ -33,8 +35,9 @@ type UsageLogRepository interface {
// Admin dashboard stats
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
- GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
- GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
+ GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
+ GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
+ GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
@@ -62,6 +65,10 @@ type UsageLogRepository interface {
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
}
+type accountWindowStatsBatchReader interface {
+ GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error)
+}
+
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at)
type apiUsageCache struct {
response *ClaudeUsageResponse
@@ -297,7 +304,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
}
dayStart := geminiDailyWindowStart(now)
- stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil)
+ stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil, nil)
if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
}
@@ -319,7 +326,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart := now.Truncate(time.Minute)
minuteResetAt := minuteStart.Add(time.Minute)
- minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil)
+ minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil, nil)
if err != nil {
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
}
@@ -440,6 +447,78 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
}, nil
}
+// GetTodayStatsBatch 批量获取账号今日统计,优先走批量 SQL,失败时回退单账号查询。
+func (s *AccountUsageService) GetTodayStatsBatch(ctx context.Context, accountIDs []int64) (map[int64]*WindowStats, error) {
+ uniqueIDs := make([]int64, 0, len(accountIDs))
+ seen := make(map[int64]struct{}, len(accountIDs))
+ for _, accountID := range accountIDs {
+ if accountID <= 0 {
+ continue
+ }
+ if _, exists := seen[accountID]; exists {
+ continue
+ }
+ seen[accountID] = struct{}{}
+ uniqueIDs = append(uniqueIDs, accountID)
+ }
+
+ result := make(map[int64]*WindowStats, len(uniqueIDs))
+ if len(uniqueIDs) == 0 {
+ return result, nil
+ }
+
+ startTime := timezone.Today()
+ if batchReader, ok := s.usageLogRepo.(accountWindowStatsBatchReader); ok {
+ statsByAccount, err := batchReader.GetAccountWindowStatsBatch(ctx, uniqueIDs, startTime)
+ if err == nil {
+ for _, accountID := range uniqueIDs {
+ result[accountID] = windowStatsFromAccountStats(statsByAccount[accountID])
+ }
+ return result, nil
+ }
+ }
+
+ var mu sync.Mutex
+ g, gctx := errgroup.WithContext(ctx)
+ g.SetLimit(8)
+
+ for _, accountID := range uniqueIDs {
+ id := accountID
+ g.Go(func() error {
+ stats, err := s.usageLogRepo.GetAccountWindowStats(gctx, id, startTime)
+ if err != nil {
+ return nil
+ }
+ mu.Lock()
+ result[id] = windowStatsFromAccountStats(stats)
+ mu.Unlock()
+ return nil
+ })
+ }
+
+ _ = g.Wait()
+
+ for _, accountID := range uniqueIDs {
+ if _, ok := result[accountID]; !ok {
+ result[accountID] = &WindowStats{}
+ }
+ }
+ return result, nil
+}
+
+func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats {
+ if stats == nil {
+ return &WindowStats{}
+ }
+ return &WindowStats{
+ Requests: stats.Requests,
+ Tokens: stats.Tokens,
+ Cost: stats.Cost,
+ StandardCost: stats.StandardCost,
+ UserCost: stats.UserCost,
+ }
+}
+
func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime)
if err != nil {
diff --git a/backend/internal/service/account_wildcard_test.go b/backend/internal/service/account_wildcard_test.go
index 6a9acc68..7782f948 100644
--- a/backend/internal/service/account_wildcard_test.go
+++ b/backend/internal/service/account_wildcard_test.go
@@ -314,3 +314,72 @@ func TestAccountGetModelMapping_AntigravityRespectsWildcardOverride(t *testing.T
t.Fatalf("expected wildcard mapping to stay effective, got: %q", mapped)
}
}
+
+func TestAccountGetModelMapping_CacheInvalidatesOnCredentialsReplace(t *testing.T) {
+ account := &Account{
+ Credentials: map[string]any{
+ "model_mapping": map[string]any{
+ "claude-3-5-sonnet": "upstream-a",
+ },
+ },
+ }
+
+ first := account.GetModelMapping()
+ if first["claude-3-5-sonnet"] != "upstream-a" {
+ t.Fatalf("unexpected first mapping: %v", first)
+ }
+
+ account.Credentials = map[string]any{
+ "model_mapping": map[string]any{
+ "claude-3-5-sonnet": "upstream-b",
+ },
+ }
+ second := account.GetModelMapping()
+ if second["claude-3-5-sonnet"] != "upstream-b" {
+ t.Fatalf("expected cache invalidated after credentials replace, got: %v", second)
+ }
+}
+
+func TestAccountGetModelMapping_CacheInvalidatesOnMappingLenChange(t *testing.T) {
+ rawMapping := map[string]any{
+ "claude-sonnet": "sonnet-a",
+ }
+ account := &Account{
+ Credentials: map[string]any{
+ "model_mapping": rawMapping,
+ },
+ }
+
+ first := account.GetModelMapping()
+ if len(first) != 1 {
+ t.Fatalf("unexpected first mapping length: %d", len(first))
+ }
+
+ rawMapping["claude-opus"] = "opus-b"
+ second := account.GetModelMapping()
+ if second["claude-opus"] != "opus-b" {
+ t.Fatalf("expected cache invalidated after mapping len change, got: %v", second)
+ }
+}
+
+func TestAccountGetModelMapping_CacheInvalidatesOnInPlaceValueChange(t *testing.T) {
+ rawMapping := map[string]any{
+ "claude-sonnet": "sonnet-a",
+ }
+ account := &Account{
+ Credentials: map[string]any{
+ "model_mapping": rawMapping,
+ },
+ }
+
+ first := account.GetModelMapping()
+ if first["claude-sonnet"] != "sonnet-a" {
+ t.Fatalf("unexpected first mapping: %v", first)
+ }
+
+ rawMapping["claude-sonnet"] = "sonnet-b"
+ second := account.GetModelMapping()
+ if second["claude-sonnet"] != "sonnet-b" {
+ t.Fatalf("expected cache invalidated after in-place value change, got: %v", second)
+ }
+}
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 47339661..bdd1aa4a 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -9,6 +9,8 @@ import (
"strings"
"time"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -42,6 +44,9 @@ type AdminService interface {
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
+ // API Key management (admin)
+ AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error)
+
// Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error)
@@ -83,13 +88,14 @@ type AdminService interface {
// CreateUserInput represents input for creating a new user via admin operations.
type CreateUserInput struct {
- Email string
- Password string
- Username string
- Notes string
- Balance float64
- Concurrency int
- AllowedGroups []int64
+ Email string
+ Password string
+ Username string
+ Notes string
+ Balance float64
+ Concurrency int
+ AllowedGroups []int64
+ SoraStorageQuotaBytes int64
}
type UpdateUserInput struct {
@@ -103,7 +109,8 @@ type UpdateUserInput struct {
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置
// map[groupID]*rate,nil 表示删除该分组的专属倍率
- GroupRates map[int64]*float64
+ GroupRates map[int64]*float64
+ SoraStorageQuotaBytes *int64
}
type CreateGroupInput struct {
@@ -135,6 +142,8 @@ type CreateGroupInput struct {
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string
+ // Sora 存储配额
+ SoraStorageQuotaBytes int64
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64
}
@@ -169,6 +178,8 @@ type UpdateGroupInput struct {
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string
+ // Sora 存储配额
+ SoraStorageQuotaBytes *int64
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64
}
@@ -236,6 +247,14 @@ type BulkUpdateAccountResult struct {
Error string `json:"error,omitempty"`
}
+// AdminUpdateAPIKeyGroupIDResult is the result of AdminUpdateAPIKeyGroupID.
+type AdminUpdateAPIKeyGroupIDResult struct {
+ APIKey *APIKey
+ AutoGrantedGroupAccess bool // true if a new exclusive group permission was auto-added
+ GrantedGroupID *int64 // the group ID that was auto-granted
+ GrantedGroupName string // the group name that was auto-granted
+}
+
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
type BulkUpdateAccountsResult struct {
Success int `json:"success"`
@@ -400,6 +419,17 @@ type adminServiceImpl struct {
proxyProber ProxyExitInfoProber
proxyLatencyCache ProxyLatencyCache
authCacheInvalidator APIKeyAuthCacheInvalidator
+ entClient *dbent.Client // 用于开启数据库事务
+ settingService *SettingService
+ defaultSubAssigner DefaultSubscriptionAssigner
+}
+
+type userGroupRateBatchReader interface {
+ GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
+}
+
+type groupExistenceBatchReader interface {
+ ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
// NewAdminService creates a new AdminService
@@ -416,6 +446,9 @@ func NewAdminService(
proxyProber ProxyExitInfoProber,
proxyLatencyCache ProxyLatencyCache,
authCacheInvalidator APIKeyAuthCacheInvalidator,
+ entClient *dbent.Client,
+ settingService *SettingService,
+ defaultSubAssigner DefaultSubscriptionAssigner,
) AdminService {
return &adminServiceImpl{
userRepo: userRepo,
@@ -430,6 +463,9 @@ func NewAdminService(
proxyProber: proxyProber,
proxyLatencyCache: proxyLatencyCache,
authCacheInvalidator: authCacheInvalidator,
+ entClient: entClient,
+ settingService: settingService,
+ defaultSubAssigner: defaultSubAssigner,
}
}
@@ -442,18 +478,43 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
}
// 批量加载用户专属分组倍率
if s.userGroupRateRepo != nil && len(users) > 0 {
- for i := range users {
- rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
- if err != nil {
- logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
- continue
+ if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok {
+ userIDs := make([]int64, 0, len(users))
+ for i := range users {
+ userIDs = append(userIDs, users[i].ID)
}
- users[i].GroupRates = rates
+ ratesByUser, err := batchRepo.GetByUserIDs(ctx, userIDs)
+ if err != nil {
+ logger.LegacyPrintf("service.admin", "failed to load user group rates in batch: err=%v", err)
+ s.loadUserGroupRatesOneByOne(ctx, users)
+ } else {
+ for i := range users {
+ if rates, ok := ratesByUser[users[i].ID]; ok {
+ users[i].GroupRates = rates
+ }
+ }
+ }
+ } else {
+ s.loadUserGroupRatesOneByOne(ctx, users)
}
}
return users, result.Total, nil
}
+func (s *adminServiceImpl) loadUserGroupRatesOneByOne(ctx context.Context, users []User) {
+ if s.userGroupRateRepo == nil {
+ return
+ }
+ for i := range users {
+ rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
+ if err != nil {
+ logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
+ continue
+ }
+ users[i].GroupRates = rates
+ }
+}
+
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
@@ -473,14 +534,15 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error)
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
user := &User{
- Email: input.Email,
- Username: input.Username,
- Notes: input.Notes,
- Role: RoleUser, // Always create as regular user, never admin
- Balance: input.Balance,
- Concurrency: input.Concurrency,
- Status: StatusActive,
- AllowedGroups: input.AllowedGroups,
+ Email: input.Email,
+ Username: input.Username,
+ Notes: input.Notes,
+ Role: RoleUser, // Always create as regular user, never admin
+ Balance: input.Balance,
+ Concurrency: input.Concurrency,
+ Status: StatusActive,
+ AllowedGroups: input.AllowedGroups,
+ SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
}
if err := user.SetPassword(input.Password); err != nil {
return nil, err
@@ -488,9 +550,27 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
if err := s.userRepo.Create(ctx, user); err != nil {
return nil, err
}
+ s.assignDefaultSubscriptions(ctx, user.ID)
return user, nil
}
+func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userID int64) {
+ if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
+ return
+ }
+ items := s.settingService.GetDefaultSubscriptions(ctx)
+ for _, item := range items {
+ if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
+ UserID: userID,
+ GroupID: item.GroupID,
+ ValidityDays: item.ValidityDays,
+ Notes: "auto assigned by default user subscriptions setting",
+ }); err != nil {
+ logger.LegacyPrintf("service.admin", "failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err)
+ }
+ }
+}
+
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
@@ -534,6 +614,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
user.AllowedGroups = *input.AllowedGroups
}
+ if input.SoraStorageQuotaBytes != nil {
+ user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
+ }
+
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
@@ -820,6 +904,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ModelRouting: input.ModelRouting,
MCPXMLInject: mcpXMLInject,
SupportedModelScopes: input.SupportedModelScopes,
+ SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
@@ -982,6 +1067,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.SoraVideoPricePerRequestHD != nil {
group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD)
}
+ if input.SoraStorageQuotaBytes != nil {
+ group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
+ }
// Claude Code 客户端限制
if input.ClaudeCodeOnly != nil {
@@ -1137,6 +1225,103 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []
return s.groupRepo.UpdateSortOrders(ctx, updates)
}
+// AdminUpdateAPIKeyGroupID 管理员修改 API Key 分组绑定
+// groupID: nil=不修改, 指向0=解绑, 指向正整数=绑定到目标分组
+func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) {
+ apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID)
+ if err != nil {
+ return nil, err
+ }
+
+ if groupID == nil {
+ // nil 表示不修改,直接返回
+ return &AdminUpdateAPIKeyGroupIDResult{APIKey: apiKey}, nil
+ }
+
+ if *groupID < 0 {
+ return nil, infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative")
+ }
+
+ result := &AdminUpdateAPIKeyGroupIDResult{}
+
+ if *groupID == 0 {
+ // 0 表示解绑分组(不修改 user_allowed_groups,避免影响用户其他 Key)
+ apiKey.GroupID = nil
+ apiKey.Group = nil
+ } else {
+ // 验证目标分组存在且状态为 active
+ group, err := s.groupRepo.GetByID(ctx, *groupID)
+ if err != nil {
+ return nil, err
+ }
+ if group.Status != StatusActive {
+ return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active")
+ }
+ // 订阅类型分组:不允许通过此 API 直接绑定,需通过订阅管理流程
+ if group.IsSubscriptionType() {
+ return nil, infraerrors.BadRequest("SUBSCRIPTION_GROUP_NOT_ALLOWED", "subscription groups must be managed through the subscription workflow")
+ }
+
+ gid := *groupID
+ apiKey.GroupID = &gid
+ apiKey.Group = group
+
+ // 专属标准分组:使用事务保证「添加分组权限」与「更新 API Key」的原子性
+ if group.IsExclusive {
+ opCtx := ctx
+ var tx *dbent.Tx
+ if s.entClient == nil {
+ logger.LegacyPrintf("service.admin", "Warning: entClient is nil, skipping transaction protection for exclusive group binding")
+ } else {
+ var txErr error
+ tx, txErr = s.entClient.Tx(ctx)
+ if txErr != nil {
+ return nil, fmt.Errorf("begin transaction: %w", txErr)
+ }
+ defer func() { _ = tx.Rollback() }()
+ opCtx = dbent.NewTxContext(ctx, tx)
+ }
+
+ if addErr := s.userRepo.AddGroupToAllowedGroups(opCtx, apiKey.UserID, gid); addErr != nil {
+ return nil, fmt.Errorf("add group to user allowed groups: %w", addErr)
+ }
+ if err := s.apiKeyRepo.Update(opCtx, apiKey); err != nil {
+ return nil, fmt.Errorf("update api key: %w", err)
+ }
+ if tx != nil {
+ if err := tx.Commit(); err != nil {
+ return nil, fmt.Errorf("commit transaction: %w", err)
+ }
+ }
+
+ result.AutoGrantedGroupAccess = true
+ result.GrantedGroupID = &gid
+ result.GrantedGroupName = group.Name
+
+ // 失效认证缓存(在事务提交后执行)
+ if s.authCacheInvalidator != nil {
+ s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
+ }
+
+ result.APIKey = apiKey
+ return result, nil
+ }
+ }
+
+ // 非专属分组 / 解绑:无需事务,单步更新即可
+ if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
+ return nil, fmt.Errorf("update api key: %w", err)
+ }
+
+ // 失效认证缓存
+ if s.authCacheInvalidator != nil {
+ s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
+ }
+
+ result.APIKey = apiKey
+ return result, nil
+}
+
// Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
@@ -1188,6 +1373,18 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
}
}
+ // Sora apikey 账号的 base_url 必填校验
+ if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey {
+ baseURL, _ := input.Credentials["base_url"].(string)
+ baseURL = strings.TrimSpace(baseURL)
+ if baseURL == "" {
+ return nil, errors.New("sora apikey 账号必须设置 base_url")
+ }
+ if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
+ return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
+ }
+ }
+
account := &Account{
Name: input.Name,
Notes: normalizeAccountNotes(input.Notes),
@@ -1301,12 +1498,22 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
}
+ // Sora apikey 账号的 base_url 必填校验
+ if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey {
+ baseURL, _ := account.Credentials["base_url"].(string)
+ baseURL = strings.TrimSpace(baseURL)
+ if baseURL == "" {
+ return nil, errors.New("sora apikey 账号必须设置 base_url")
+ }
+ if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
+ return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
+ }
+ }
+
// 先验证分组是否存在(在任何写操作之前)
if input.GroupIDs != nil {
- for _, groupID := range *input.GroupIDs {
- if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil {
- return nil, fmt.Errorf("get group: %w", err)
- }
+ if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
+ return nil, err
}
// 检查混合渠道风险(除非用户已确认)
@@ -1348,22 +1555,37 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if len(input.AccountIDs) == 0 {
return result, nil
}
+ if input.GroupIDs != nil {
+ if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
+ return nil, err
+ }
+ }
needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck
- // 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。
+ // 预加载账号平台信息(混合渠道检查需要)。
platformByID := map[int64]string{}
if needMixedChannelCheck {
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
if err != nil {
- if needMixedChannelCheck {
- return nil, err
+ return nil, err
+ }
+ for _, account := range accounts {
+ if account != nil {
+ platformByID[account.ID] = account.Platform
}
- } else {
- for _, account := range accounts {
- if account != nil {
- platformByID[account.ID] = account.Platform
- }
+ }
+ }
+
+ // 预检查混合渠道风险:在任何写操作之前,若发现风险立即返回错误。
+ if needMixedChannelCheck {
+ for _, accountID := range input.AccountIDs {
+ platform := platformByID[accountID]
+ if platform == "" {
+ continue
+ }
+ if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil {
+ return nil, err
}
}
}
@@ -1411,31 +1633,6 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
entry := BulkUpdateAccountResult{AccountID: accountID}
if input.GroupIDs != nil {
- // 检查混合渠道风险(除非用户已确认)
- if !input.SkipMixedChannelCheck {
- platform := platformByID[accountID]
- if platform == "" {
- account, err := s.accountRepo.GetByID(ctx, accountID)
- if err != nil {
- entry.Success = false
- entry.Error = err.Error()
- result.Failed++
- result.FailedIDs = append(result.FailedIDs, accountID)
- result.Results = append(result.Results, entry)
- continue
- }
- platform = account.Platform
- }
- if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil {
- entry.Success = false
- entry.Error = err.Error()
- result.Failed++
- result.FailedIDs = append(result.FailedIDs, accountID)
- result.Results = append(result.Results, entry)
- continue
- }
- }
-
if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil {
entry.Success = false
entry.Error = err.Error()
@@ -2115,6 +2312,35 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
return nil
}
+func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error {
+ if len(groupIDs) == 0 {
+ return nil
+ }
+ if s.groupRepo == nil {
+ return errors.New("group repository not configured")
+ }
+
+ if batchReader, ok := s.groupRepo.(groupExistenceBatchReader); ok {
+ existsByID, err := batchReader.ExistsByIDs(ctx, groupIDs)
+ if err != nil {
+ return fmt.Errorf("check groups exists: %w", err)
+ }
+ for _, groupID := range groupIDs {
+ if groupID <= 0 || !existsByID[groupID] {
+ return fmt.Errorf("get group: %w", ErrGroupNotFound)
+ }
+ }
+ return nil
+ }
+
+ for _, groupID := range groupIDs {
+ if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil {
+ return fmt.Errorf("get group: %w", err)
+ }
+ }
+ return nil
+}
+
// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform.
func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs)
diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go
new file mode 100644
index 00000000..9210a786
--- /dev/null
+++ b/backend/internal/service/admin_service_apikey_test.go
@@ -0,0 +1,420 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+)
+
+// ---------------------------------------------------------------------------
+// Stubs
+// ---------------------------------------------------------------------------
+
+// userRepoStubForGroupUpdate implements UserRepository for AdminUpdateAPIKeyGroupID tests.
+type userRepoStubForGroupUpdate struct {
+ addGroupErr error
+ addGroupCalled bool
+ addedUserID int64
+ addedGroupID int64
+}
+
+func (s *userRepoStubForGroupUpdate) AddGroupToAllowedGroups(_ context.Context, userID int64, groupID int64) error {
+ s.addGroupCalled = true
+ s.addedUserID = userID
+ s.addedGroupID = groupID
+ return s.addGroupErr
+}
+
+func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
+
+// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
+type apiKeyRepoStubForGroupUpdate struct {
+ key *APIKey
+ getErr error
+ updateErr error
+ updated *APIKey // captures what was passed to Update
+}
+
+func (s *apiKeyRepoStubForGroupUpdate) GetByID(_ context.Context, _ int64) (*APIKey, error) {
+ if s.getErr != nil {
+ return nil, s.getErr
+ }
+ clone := *s.key
+ return &clone, nil
+}
+func (s *apiKeyRepoStubForGroupUpdate) Update(_ context.Context, key *APIKey) error {
+ if s.updateErr != nil {
+ return s.updateErr
+ }
+ clone := *key
+ s.updated = &clone
+ return nil
+}
+
+// Unused methods – panic on unexpected call.
+func (s *apiKeyRepoStubForGroupUpdate) Create(context.Context, *APIKey) error { panic("unexpected") }
+func (s *apiKeyRepoStubForGroupUpdate) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) GetByKey(context.Context, string) (*APIKey, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) GetByKeyForAuth(context.Context, string) (*APIKey, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
+func (s *apiKeyRepoStubForGroupUpdate) ListByUserID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) CountByUserID(context.Context, int64) (int64, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) ExistsByKey(context.Context, string) (bool, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) ListKeysByUserID(context.Context, int64) ([]string, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) ListKeysByGroupID(context.Context, int64) ([]string, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) UpdateLastUsed(context.Context, int64, time.Time) error {
+ panic("unexpected")
+}
+
+// groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests.
+type groupRepoStubForGroupUpdate struct {
+ group *Group
+ getErr error
+ lastGetByIDArg int64
+}
+
+func (s *groupRepoStubForGroupUpdate) GetByID(_ context.Context, id int64) (*Group, error) {
+ s.lastGetByIDArg = id
+ if s.getErr != nil {
+ return nil, s.getErr
+ }
+ clone := *s.group
+ return &clone, nil
+}
+
+// Unused methods – panic on unexpected call.
+func (s *groupRepoStubForGroupUpdate) Create(context.Context, *Group) error { panic("unexpected") }
+func (s *groupRepoStubForGroupUpdate) GetByIDLite(context.Context, int64) (*Group, error) {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) Update(context.Context, *Group) error { panic("unexpected") }
+func (s *groupRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
+func (s *groupRepoStubForGroupUpdate) DeleteCascade(context.Context, int64) ([]int64, error) {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) ListActive(context.Context) ([]Group, error) {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, string) ([]Group, error) {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) BindAccountsToGroup(context.Context, int64, []int64) error {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error {
+ panic("unexpected")
+}
+
+// ---------------------------------------------------------------------------
+// Tests
+// ---------------------------------------------------------------------------
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_KeyNotFound(t *testing.T) {
+ repo := &apiKeyRepoStubForGroupUpdate{getErr: ErrAPIKeyNotFound}
+ svc := &adminServiceImpl{apiKeyRepo: repo}
+
+ _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 999, int64Ptr(1))
+ require.ErrorIs(t, err, ErrAPIKeyNotFound)
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_NilGroupID_NoOp(t *testing.T) {
+ existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(5)}
+ repo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ svc := &adminServiceImpl{apiKeyRepo: repo}
+
+ got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, nil)
+ require.NoError(t, err)
+ require.Equal(t, int64(1), got.APIKey.ID)
+ // Update should NOT have been called (updated stays nil)
+ require.Nil(t, repo.updated)
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_Unbind(t *testing.T) {
+ existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(5), Group: &Group{ID: 5, Name: "Old"}}
+ repo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ cache := &authCacheInvalidatorStub{}
+ svc := &adminServiceImpl{apiKeyRepo: repo, authCacheInvalidator: cache}
+
+ got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
+ require.NoError(t, err)
+ require.Nil(t, got.APIKey.GroupID, "group_id should be nil after unbind")
+ require.Nil(t, got.APIKey.Group, "group object should be nil after unbind")
+ require.NotNil(t, repo.updated, "Update should have been called")
+ require.Nil(t, repo.updated.GroupID)
+ require.Equal(t, []string{"sk-test"}, cache.keys, "cache should be invalidated")
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_BindActiveGroup(t *testing.T) {
+ existing := &APIKey{ID: 1, Key: "sk-test", GroupID: nil}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
+ cache := &authCacheInvalidatorStub{}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
+
+ got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
+ require.NoError(t, err)
+ require.NotNil(t, got.APIKey.GroupID)
+ require.Equal(t, int64(10), *got.APIKey.GroupID)
+ require.Equal(t, int64(10), *apiKeyRepo.updated.GroupID)
+ require.Equal(t, []string{"sk-test"}, cache.keys)
+ // M3: verify correct group ID was passed to repo
+ require.Equal(t, int64(10), groupRepo.lastGetByIDArg)
+ // C1 fix: verify Group object is populated
+ require.NotNil(t, got.APIKey.Group)
+ require.Equal(t, "Pro", got.APIKey.Group.Name)
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_SameGroup_Idempotent(t *testing.T) {
+ existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(10), Group: &Group{ID: 10, Name: "Pro"}}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
+ cache := &authCacheInvalidatorStub{}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
+
+ got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
+ require.NoError(t, err)
+ require.NotNil(t, got.APIKey.GroupID)
+ require.Equal(t, int64(10), *got.APIKey.GroupID)
+ // Update is still called (current impl doesn't short-circuit on same group)
+ require.NotNil(t, apiKeyRepo.updated)
+ require.Equal(t, []string{"sk-test"}, cache.keys)
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_GroupNotFound(t *testing.T) {
+ existing := &APIKey{ID: 1, Key: "sk-test"}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ groupRepo := &groupRepoStubForGroupUpdate{getErr: ErrGroupNotFound}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
+
+ _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(99))
+ require.ErrorIs(t, err, ErrGroupNotFound)
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_GroupNotActive(t *testing.T) {
+ existing := &APIKey{ID: 1, Key: "sk-test"}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 5, Status: StatusDisabled}}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
+
+ _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(5))
+ require.Error(t, err)
+ require.Equal(t, "GROUP_NOT_ACTIVE", infraerrors.Reason(err))
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_UpdateFails(t *testing.T) {
+ existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(3)}
+ repo := &apiKeyRepoStubForGroupUpdate{key: existing, updateErr: errors.New("db write error")}
+ svc := &adminServiceImpl{apiKeyRepo: repo}
+
+ _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "update api key")
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_NegativeGroupID(t *testing.T) {
+ existing := &APIKey{ID: 1, Key: "sk-test"}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo}
+
+ _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(-5))
+ require.Error(t, err)
+ require.Equal(t, "INVALID_GROUP_ID", infraerrors.Reason(err))
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_PointerIsolation(t *testing.T) {
+ existing := &APIKey{ID: 1, Key: "sk-test", GroupID: nil}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
+ cache := &authCacheInvalidatorStub{}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
+
+ inputGID := int64(10)
+ got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, &inputGID)
+ require.NoError(t, err)
+ require.NotNil(t, got.APIKey.GroupID)
+ // Mutating the input pointer must NOT affect the stored value
+ inputGID = 999
+ require.Equal(t, int64(10), *got.APIKey.GroupID)
+ require.Equal(t, int64(10), *apiKeyRepo.updated.GroupID)
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_NilCacheInvalidator(t *testing.T) {
+ existing := &APIKey{ID: 1, Key: "sk-test"}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 7, Status: StatusActive}}
+ // authCacheInvalidator is nil – should not panic
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
+
+ got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(7))
+ require.NoError(t, err)
+ require.NotNil(t, got.APIKey.GroupID)
+ require.Equal(t, int64(7), *got.APIKey.GroupID)
+}
+
+// ---------------------------------------------------------------------------
+// Tests: AllowedGroup auto-sync
+// ---------------------------------------------------------------------------
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_ExclusiveGroup_AddsAllowedGroup(t *testing.T) {
+ existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Exclusive", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeStandard}}
+ userRepo := &userRepoStubForGroupUpdate{}
+ cache := &authCacheInvalidatorStub{}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, authCacheInvalidator: cache}
+
+ got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
+ require.NoError(t, err)
+ require.NotNil(t, got.APIKey.GroupID)
+ require.Equal(t, int64(10), *got.APIKey.GroupID)
+ // 验证 AddGroupToAllowedGroups 被调用,且参数正确
+ require.True(t, userRepo.addGroupCalled)
+ require.Equal(t, int64(42), userRepo.addedUserID)
+ require.Equal(t, int64(10), userRepo.addedGroupID)
+ // 验证 result 标记了自动授权
+ require.True(t, got.AutoGrantedGroupAccess)
+ require.NotNil(t, got.GrantedGroupID)
+ require.Equal(t, int64(10), *got.GrantedGroupID)
+ require.Equal(t, "Exclusive", got.GrantedGroupName)
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupUpdate(t *testing.T) {
+ existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Public", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeStandard}}
+ userRepo := &userRepoStubForGroupUpdate{}
+ cache := &authCacheInvalidatorStub{}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, authCacheInvalidator: cache}
+
+ got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
+ require.NoError(t, err)
+ require.NotNil(t, got.APIKey.GroupID)
+ // 非专属分组不触发 AddGroupToAllowedGroups
+ require.False(t, userRepo.addGroupCalled)
+ require.False(t, got.AutoGrantedGroupAccess)
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) {
+ existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}}
+ userRepo := &userRepoStubForGroupUpdate{}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
+
+ // 订阅类型分组应被阻止绑定
+ _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
+ require.Error(t, err)
+ require.Equal(t, "SUBSCRIPTION_GROUP_NOT_ALLOWED", infraerrors.Reason(err))
+ require.False(t, userRepo.addGroupCalled)
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_ExclusiveGroup_AllowedGroupAddFails_ReturnsError(t *testing.T) {
+ existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Exclusive", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeStandard}}
+ userRepo := &userRepoStubForGroupUpdate{addGroupErr: errors.New("db error")}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
+
+ // 严格模式:AddGroupToAllowedGroups 失败时,整体操作报错
+ _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "add group to user allowed groups")
+ require.True(t, userRepo.addGroupCalled)
+ // apiKey 不应被更新
+ require.Nil(t, apiKeyRepo.updated)
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_Unbind_NoAllowedGroupUpdate(t *testing.T) {
+ existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: int64Ptr(10), Group: &Group{ID: 10, Name: "Exclusive"}}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ userRepo := &userRepoStubForGroupUpdate{}
+ cache := &authCacheInvalidatorStub{}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, userRepo: userRepo, authCacheInvalidator: cache}
+
+ got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
+ require.NoError(t, err)
+ require.Nil(t, got.APIKey.GroupID)
+ // 解绑时不修改 allowed_groups
+ require.False(t, userRepo.addGroupCalled)
+ require.False(t, got.AutoGrantedGroupAccess)
+}
diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go
index 0dccacbb..4845d87c 100644
--- a/backend/internal/service/admin_service_bulk_update_test.go
+++ b/backend/internal/service/admin_service_bulk_update_test.go
@@ -15,6 +15,7 @@ type accountRepoStubForBulkUpdate struct {
bulkUpdateErr error
bulkUpdateIDs []int64
bindGroupErrByID map[int64]error
+ bindGroupsCalls []int64
getByIDsAccounts []*Account
getByIDsErr error
getByIDsCalled bool
@@ -22,6 +23,8 @@ type accountRepoStubForBulkUpdate struct {
getByIDAccounts map[int64]*Account
getByIDErrByID map[int64]error
getByIDCalled []int64
+ listByGroupData map[int64][]Account
+ listByGroupErr map[int64]error
}
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
@@ -33,6 +36,7 @@ func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64
}
func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID int64, _ []int64) error {
+ s.bindGroupsCalls = append(s.bindGroupsCalls, accountID)
if err, ok := s.bindGroupErrByID[accountID]; ok {
return err
}
@@ -59,6 +63,16 @@ func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Ac
return nil, errors.New("account not found")
}
+func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
+ if err, ok := s.listByGroupErr[groupID]; ok {
+ return nil, err
+ }
+ if rows, ok := s.listByGroupData[groupID]; ok {
+ return rows, nil
+ }
+ return nil, nil
+}
+
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{}
@@ -86,7 +100,10 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
2: errors.New("bind failed"),
},
}
- svc := &adminServiceImpl{accountRepo: repo}
+ svc := &adminServiceImpl{
+ accountRepo: repo,
+ groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "g10"}},
+ }
groupIDs := []int64{10}
schedulable := false
@@ -105,3 +122,51 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
require.ElementsMatch(t, []int64{2}, result.FailedIDs)
require.Len(t, result.Results, 3)
}
+
+func TestAdminService_BulkUpdateAccounts_NilGroupRepoReturnsError(t *testing.T) {
+ repo := &accountRepoStubForBulkUpdate{}
+ svc := &adminServiceImpl{accountRepo: repo}
+
+ groupIDs := []int64{10}
+ input := &BulkUpdateAccountsInput{
+ AccountIDs: []int64{1},
+ GroupIDs: &groupIDs,
+ }
+
+ result, err := svc.BulkUpdateAccounts(context.Background(), input)
+ require.Nil(t, result)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "group repository not configured")
+}
+
+// TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict verifies
+// that the global pre-check detects a conflict with existing group members and returns an
+// error before any DB write is performed.
+func TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict(t *testing.T) {
+ repo := &accountRepoStubForBulkUpdate{
+ getByIDsAccounts: []*Account{
+ {ID: 1, Platform: PlatformAntigravity},
+ },
+ // Group 10 already contains an Anthropic account.
+ listByGroupData: map[int64][]Account{
+ 10: {{ID: 99, Platform: PlatformAnthropic}},
+ },
+ }
+ svc := &adminServiceImpl{
+ accountRepo: repo,
+ groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "target-group"}},
+ }
+
+ groupIDs := []int64{10}
+ input := &BulkUpdateAccountsInput{
+ AccountIDs: []int64{1},
+ GroupIDs: &groupIDs,
+ }
+
+ result, err := svc.BulkUpdateAccounts(context.Background(), input)
+ require.Nil(t, result)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "mixed channel")
+ // No BindGroups should have been called since the check runs before any write.
+ require.Empty(t, repo.bindGroupsCalls)
+}
diff --git a/backend/internal/service/admin_service_create_user_test.go b/backend/internal/service/admin_service_create_user_test.go
index a0fe4d87..c5b1e38d 100644
--- a/backend/internal/service/admin_service_create_user_test.go
+++ b/backend/internal/service/admin_service_create_user_test.go
@@ -7,6 +7,7 @@ import (
"errors"
"testing"
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
@@ -65,3 +66,32 @@ func TestAdminService_CreateUser_CreateError(t *testing.T) {
require.ErrorIs(t, err, createErr)
require.Empty(t, repo.created)
}
+
+func TestAdminService_CreateUser_AssignsDefaultSubscriptions(t *testing.T) {
+ repo := &userRepoStub{nextID: 21}
+ assigner := &defaultSubscriptionAssignerStub{}
+ cfg := &config.Config{
+ Default: config.DefaultConfig{
+ UserBalance: 0,
+ UserConcurrency: 1,
+ },
+ }
+ settingService := NewSettingService(&settingRepoStub{values: map[string]string{
+ SettingKeyDefaultSubscriptions: `[{"group_id":5,"validity_days":30}]`,
+ }}, cfg)
+ svc := &adminServiceImpl{
+ userRepo: repo,
+ settingService: settingService,
+ defaultSubAssigner: assigner,
+ }
+
+ _, err := svc.CreateUser(context.Background(), &CreateUserInput{
+ Email: "new-user@test.com",
+ Password: "password",
+ })
+ require.NoError(t, err)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(21), assigner.calls[0].UserID)
+ require.Equal(t, int64(5), assigner.calls[0].GroupID)
+ require.Equal(t, 30, assigner.calls[0].ValidityDays)
+}
diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go
index 60fa3d77..bb906df5 100644
--- a/backend/internal/service/admin_service_delete_test.go
+++ b/backend/internal/service/admin_service_delete_test.go
@@ -93,6 +93,10 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic("unexpected RemoveGroupFromAllowedGroups call")
}
+func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
+ panic("unexpected AddGroupToAllowedGroups call")
+}
+
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
diff --git a/backend/internal/service/admin_service_list_users_test.go b/backend/internal/service/admin_service_list_users_test.go
new file mode 100644
index 00000000..8b50530a
--- /dev/null
+++ b/backend/internal/service/admin_service_list_users_test.go
@@ -0,0 +1,106 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+)
+
+type userRepoStubForListUsers struct {
+ userRepoStub
+ users []User
+ err error
+}
+
+func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) {
+ if s.err != nil {
+ return nil, nil, s.err
+ }
+ out := make([]User, len(s.users))
+ copy(out, s.users)
+ return out, &pagination.PaginationResult{
+ Total: int64(len(out)),
+ Page: params.Page,
+ PageSize: params.PageSize,
+ }, nil
+}
+
+type userGroupRateRepoStubForListUsers struct {
+ batchCalls int
+ singleCall []int64
+
+ batchErr error
+ batchData map[int64]map[int64]float64
+
+ singleErr map[int64]error
+ singleData map[int64]map[int64]float64
+}
+
+func (s *userGroupRateRepoStubForListUsers) GetByUserIDs(_ context.Context, _ []int64) (map[int64]map[int64]float64, error) {
+ s.batchCalls++
+ if s.batchErr != nil {
+ return nil, s.batchErr
+ }
+ return s.batchData, nil
+}
+
+func (s *userGroupRateRepoStubForListUsers) GetByUserID(_ context.Context, userID int64) (map[int64]float64, error) {
+ s.singleCall = append(s.singleCall, userID)
+ if err, ok := s.singleErr[userID]; ok {
+ return nil, err
+ }
+ if rates, ok := s.singleData[userID]; ok {
+ return rates, nil
+ }
+ return map[int64]float64{}, nil
+}
+
+func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context, userID, groupID int64) (*float64, error) {
+ panic("unexpected GetByUserAndGroup call")
+}
+
+func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error {
+ panic("unexpected SyncUserGroupRates call")
+}
+
+func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error {
+ panic("unexpected DeleteByGroupID call")
+}
+
+func (s *userGroupRateRepoStubForListUsers) DeleteByUserID(_ context.Context, userID int64) error {
+ panic("unexpected DeleteByUserID call")
+}
+
+func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
+ userRepo := &userRepoStubForListUsers{
+ users: []User{
+ {ID: 101, Username: "u1"},
+ {ID: 202, Username: "u2"},
+ },
+ }
+ rateRepo := &userGroupRateRepoStubForListUsers{
+ batchErr: errors.New("batch unavailable"),
+ singleData: map[int64]map[int64]float64{
+ 101: {11: 1.1},
+ 202: {22: 2.2},
+ },
+ }
+ svc := &adminServiceImpl{
+ userRepo: userRepo,
+ userGroupRateRepo: rateRepo,
+ }
+
+ users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{})
+ require.NoError(t, err)
+ require.Equal(t, int64(2), total)
+ require.Len(t, users, 2)
+ require.Equal(t, 1, rateRepo.batchCalls)
+ require.ElementsMatch(t, []int64{101, 202}, rateRepo.singleCall)
+ require.Equal(t, 1.1, users[0].GroupRates[11])
+ require.Equal(t, 2.2, users[1].GroupRates[22])
+}
diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go
index 108ff9ab..96ff3354 100644
--- a/backend/internal/service/antigravity_gateway_service.go
+++ b/backend/internal/service/antigravity_gateway_service.go
@@ -21,7 +21,6 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
- "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -2291,7 +2290,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
// isSingleAccountRetry 检查 context 中是否设置了单账号退避重试标记
func isSingleAccountRetry(ctx context.Context) bool {
- v, _ := ctx.Value(ctxkey.SingleAccountRetry).(bool)
+ v, _ := SingleAccountRetryFromContext(ctx)
return v
}
@@ -3757,14 +3756,17 @@ func (s *AntigravityGatewayService) extractImageSize(body []byte) string {
}
// isImageGenerationModel 判断模型是否为图片生成模型
-// 支持的模型:gemini-3-pro-image, gemini-3-pro-image-preview, gemini-2.5-flash-image 等
+// 支持的模型:gemini-3.1-flash-image, gemini-3-pro-image, gemini-2.5-flash-image 等
func isImageGenerationModel(model string) bool {
modelLower := strings.ToLower(model)
// 移除 models/ 前缀
modelLower = strings.TrimPrefix(modelLower, "models/")
// 精确匹配或前缀匹配
- return modelLower == "gemini-3-pro-image" ||
+ return modelLower == "gemini-3.1-flash-image" ||
+ modelLower == "gemini-3.1-flash-image-preview" ||
+ strings.HasPrefix(modelLower, "gemini-3.1-flash-image-") ||
+ modelLower == "gemini-3-pro-image" ||
modelLower == "gemini-3-pro-image-preview" ||
strings.HasPrefix(modelLower, "gemini-3-pro-image-") ||
modelLower == "gemini-2.5-flash-image" ||
diff --git a/backend/internal/service/api_key.go b/backend/internal/service/api_key.go
index fe1b3a5d..07523597 100644
--- a/backend/internal/service/api_key.go
+++ b/backend/internal/service/api_key.go
@@ -1,6 +1,10 @@
package service
-import "time"
+import (
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
+)
// API Key status constants
const (
@@ -19,11 +23,14 @@ type APIKey struct {
Status string
IPWhitelist []string
IPBlacklist []string
- LastUsedAt *time.Time
- CreatedAt time.Time
- UpdatedAt time.Time
- User *User
- Group *Group
+ // 预编译的 IP 规则,用于认证热路径避免重复 ParseIP/ParseCIDR。
+ CompiledIPWhitelist *ip.CompiledIPRules `json:"-"`
+ CompiledIPBlacklist *ip.CompiledIPRules `json:"-"`
+ LastUsedAt *time.Time
+ CreatedAt time.Time
+ UpdatedAt time.Time
+ User *User
+ Group *Group
// Quota fields
Quota float64 // Quota limit in USD (0 = unlimited)
diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go
index 77a75674..30eb8d74 100644
--- a/backend/internal/service/api_key_auth_cache_impl.go
+++ b/backend/internal/service/api_key_auth_cache_impl.go
@@ -298,5 +298,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
}
}
+ s.compileAPIKeyIPRules(apiKey)
return apiKey
}
diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go
index c5e1cfab..0d073077 100644
--- a/backend/internal/service/api_key_service.go
+++ b/backend/internal/service/api_key_service.go
@@ -158,6 +158,14 @@ func NewAPIKeyService(
return svc
}
+func (s *APIKeyService) compileAPIKeyIPRules(apiKey *APIKey) {
+ if apiKey == nil {
+ return
+ }
+ apiKey.CompiledIPWhitelist = ip.CompileIPRules(apiKey.IPWhitelist)
+ apiKey.CompiledIPBlacklist = ip.CompileIPRules(apiKey.IPBlacklist)
+}
+
// GenerateKey 生成随机API Key
func (s *APIKeyService) GenerateKey() (string, error) {
// 生成32字节随机数据
@@ -332,6 +340,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
}
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
+ s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
@@ -363,6 +372,7 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
+ s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
@@ -375,6 +385,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
+ s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
}
@@ -391,6 +402,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
+ s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
} else {
@@ -402,6 +414,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
+ s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
}
@@ -411,6 +424,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
return nil, fmt.Errorf("get api key: %w", err)
}
apiKey.Key = key
+ s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
@@ -510,6 +524,7 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
}
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
+ s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index 663e1215..9df61c44 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -56,15 +56,20 @@ type JWTClaims struct {
// AuthService 认证服务
type AuthService struct {
- userRepo UserRepository
- redeemRepo RedeemCodeRepository
- refreshTokenCache RefreshTokenCache
- cfg *config.Config
- settingService *SettingService
- emailService *EmailService
- turnstileService *TurnstileService
- emailQueueService *EmailQueueService
- promoService *PromoService
+ userRepo UserRepository
+ redeemRepo RedeemCodeRepository
+ refreshTokenCache RefreshTokenCache
+ cfg *config.Config
+ settingService *SettingService
+ emailService *EmailService
+ turnstileService *TurnstileService
+ emailQueueService *EmailQueueService
+ promoService *PromoService
+ defaultSubAssigner DefaultSubscriptionAssigner
+}
+
+type DefaultSubscriptionAssigner interface {
+ AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error)
}
// NewAuthService 创建认证服务实例
@@ -78,17 +83,19 @@ func NewAuthService(
turnstileService *TurnstileService,
emailQueueService *EmailQueueService,
promoService *PromoService,
+ defaultSubAssigner DefaultSubscriptionAssigner,
) *AuthService {
return &AuthService{
- userRepo: userRepo,
- redeemRepo: redeemRepo,
- refreshTokenCache: refreshTokenCache,
- cfg: cfg,
- settingService: settingService,
- emailService: emailService,
- turnstileService: turnstileService,
- emailQueueService: emailQueueService,
- promoService: promoService,
+ userRepo: userRepo,
+ redeemRepo: redeemRepo,
+ refreshTokenCache: refreshTokenCache,
+ cfg: cfg,
+ settingService: settingService,
+ emailService: emailService,
+ turnstileService: turnstileService,
+ emailQueueService: emailQueueService,
+ promoService: promoService,
+ defaultSubAssigner: defaultSubAssigner,
}
}
@@ -188,6 +195,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err)
return "", nil, ErrServiceUnavailable
}
+ s.assignDefaultSubscriptions(ctx, user.ID)
// 标记邀请码为已使用(如果使用了邀请码)
if invitationRedeemCode != nil {
@@ -308,6 +316,17 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S
}, nil
}
+// VerifyTurnstileForRegister 在注册场景下验证 Turnstile。
+// 当邮箱验证开启且已提交验证码时,说明验证码发送阶段已完成 Turnstile 校验,
+// 此处跳过二次校验,避免一次性 token 在注册提交时重复使用导致误报失败。
+func (s *AuthService) VerifyTurnstileForRegister(ctx context.Context, token, remoteIP, verifyCode string) error {
+ if s.IsEmailVerifyEnabled(ctx) && strings.TrimSpace(verifyCode) != "" {
+ logger.LegacyPrintf("service.auth", "%s", "[Auth] Email verify flow detected, skip duplicate Turnstile check on register")
+ return nil
+ }
+ return s.VerifyTurnstile(ctx, token, remoteIP)
+}
+
// VerifyTurnstile 验证Turnstile token
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required
@@ -466,6 +485,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
}
} else {
user = newUser
+ s.assignDefaultSubscriptions(ctx, user.ID)
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
@@ -561,6 +581,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
}
} else {
user = newUser
+ s.assignDefaultSubscriptions(ctx, user.ID)
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
@@ -586,6 +607,23 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return tokenPair, user, nil
}
+func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
+ if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
+ return
+ }
+ items := s.settingService.GetDefaultSubscriptions(ctx)
+ for _, item := range items {
+ if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
+ UserID: userID,
+ GroupID: item.GroupID,
+ ValidityDays: item.ValidityDays,
+ Notes: "auto assigned by default user subscriptions setting",
+ }); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err)
+ }
+ }
+}
+
// ValidateToken 验证JWT token并返回用户声明
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go
index 93659743..1999e759 100644
--- a/backend/internal/service/auth_service_register_test.go
+++ b/backend/internal/service/auth_service_register_test.go
@@ -56,6 +56,21 @@ type emailCacheStub struct {
err error
}
+type defaultSubscriptionAssignerStub struct {
+ calls []AssignSubscriptionInput
+ err error
+}
+
+func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
+ if input != nil {
+ s.calls = append(s.calls, *input)
+ }
+ if s.err != nil {
+ return nil, false, s.err
+ }
+ return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
+}
+
func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) {
if s.err != nil {
return nil, s.err
@@ -123,6 +138,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
nil,
nil,
nil, // promoService
+ nil, // defaultSubAssigner
)
}
@@ -381,3 +397,23 @@ func TestAuthService_GenerateToken_UsesMinutesWhenConfigured(t *testing.T) {
require.WithinDuration(t, claims.IssuedAt.Time.Add(90*time.Minute), claims.ExpiresAt.Time, 2*time.Second)
}
+
+func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
+ repo := &userRepoStub{nextID: 42}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
+ }, nil)
+ service.defaultSubAssigner = assigner
+
+ _, user, err := service.Register(context.Background(), "default-sub@test.com", "password")
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.Len(t, assigner.calls, 2)
+ require.Equal(t, int64(42), assigner.calls[0].UserID)
+ require.Equal(t, int64(11), assigner.calls[0].GroupID)
+ require.Equal(t, 30, assigner.calls[0].ValidityDays)
+ require.Equal(t, int64(12), assigner.calls[1].GroupID)
+ require.Equal(t, 7, assigner.calls[1].ValidityDays)
+}
diff --git a/backend/internal/service/auth_service_turnstile_register_test.go b/backend/internal/service/auth_service_turnstile_register_test.go
new file mode 100644
index 00000000..36cb1e06
--- /dev/null
+++ b/backend/internal/service/auth_service_turnstile_register_test.go
@@ -0,0 +1,97 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+type turnstileVerifierSpy struct {
+ called int
+ lastToken string
+ result *TurnstileVerifyResponse
+ err error
+}
+
+func (s *turnstileVerifierSpy) VerifyToken(_ context.Context, _ string, token, _ string) (*TurnstileVerifyResponse, error) {
+ s.called++
+ s.lastToken = token
+ if s.err != nil {
+ return nil, s.err
+ }
+ if s.result != nil {
+ return s.result, nil
+ }
+ return &TurnstileVerifyResponse{Success: true}, nil
+}
+
+func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier TurnstileVerifier) *AuthService {
+ cfg := &config.Config{
+ Server: config.ServerConfig{
+ Mode: "release",
+ },
+ Turnstile: config.TurnstileConfig{
+ Required: true,
+ },
+ }
+
+ settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
+ turnstileService := NewTurnstileService(settingService, verifier)
+
+ return NewAuthService(
+ &userRepoStub{},
+ nil, // redeemRepo
+ nil, // refreshTokenCache
+ cfg,
+ settingService,
+ nil, // emailService
+ turnstileService,
+ nil, // emailQueueService
+ nil, // promoService
+ nil, // defaultSubAssigner
+ )
+}
+
+func TestAuthService_VerifyTurnstileForRegister_SkipWhenEmailVerifyCodeProvided(t *testing.T) {
+ verifier := &turnstileVerifierSpy{}
+ service := newAuthServiceForRegisterTurnstileTest(map[string]string{
+ SettingKeyEmailVerifyEnabled: "true",
+ SettingKeyTurnstileEnabled: "true",
+ SettingKeyTurnstileSecretKey: "secret",
+ SettingKeyRegistrationEnabled: "true",
+ }, verifier)
+
+ err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "123456")
+ require.NoError(t, err)
+ require.Equal(t, 0, verifier.called)
+}
+
+func TestAuthService_VerifyTurnstileForRegister_RequireWhenVerifyCodeMissing(t *testing.T) {
+ verifier := &turnstileVerifierSpy{}
+ service := newAuthServiceForRegisterTurnstileTest(map[string]string{
+ SettingKeyEmailVerifyEnabled: "true",
+ SettingKeyTurnstileEnabled: "true",
+ SettingKeyTurnstileSecretKey: "secret",
+ }, verifier)
+
+ err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "")
+ require.ErrorIs(t, err, ErrTurnstileVerificationFailed)
+}
+
+func TestAuthService_VerifyTurnstileForRegister_NoSkipWhenEmailVerifyDisabled(t *testing.T) {
+ verifier := &turnstileVerifierSpy{}
+ service := newAuthServiceForRegisterTurnstileTest(map[string]string{
+ SettingKeyEmailVerifyEnabled: "false",
+ SettingKeyTurnstileEnabled: "true",
+ SettingKeyTurnstileSecretKey: "secret",
+ }, verifier)
+
+ err := service.VerifyTurnstileForRegister(context.Background(), "turnstile-token", "127.0.0.1", "123456")
+ require.NoError(t, err)
+ require.Equal(t, 1, verifier.called)
+ require.Equal(t, "turnstile-token", verifier.lastToken)
+}
diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go
index a560930b..1a76f5f6 100644
--- a/backend/internal/service/billing_cache_service.go
+++ b/backend/internal/service/billing_cache_service.go
@@ -3,6 +3,7 @@ package service
import (
"context"
"fmt"
+ "strconv"
"sync"
"sync/atomic"
"time"
@@ -10,6 +11,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ "golang.org/x/sync/singleflight"
)
// 错误定义
@@ -58,6 +60,7 @@ const (
cacheWriteBufferSize = 1000 // 任务队列缓冲大小
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔
+ balanceLoadTimeout = 3 * time.Second
)
// cacheWriteTask 缓存写入任务
@@ -82,6 +85,9 @@ type BillingCacheService struct {
cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup
cacheWriteStopOnce sync.Once
+ cacheWriteMu sync.RWMutex
+ stopped atomic.Bool
+ balanceLoadSF singleflight.Group
// 丢弃日志节流计数器(减少高负载下日志噪音)
cacheWriteDropFullCount uint64
cacheWriteDropFullLastLog int64
@@ -105,35 +111,52 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo
// Stop 关闭缓存写入工作池
func (s *BillingCacheService) Stop() {
s.cacheWriteStopOnce.Do(func() {
- if s.cacheWriteChan == nil {
+ s.stopped.Store(true)
+
+ s.cacheWriteMu.Lock()
+ ch := s.cacheWriteChan
+ if ch != nil {
+ close(ch)
+ }
+ s.cacheWriteMu.Unlock()
+
+ if ch == nil {
return
}
- close(s.cacheWriteChan)
s.cacheWriteWg.Wait()
- s.cacheWriteChan = nil
+
+ s.cacheWriteMu.Lock()
+ if s.cacheWriteChan == ch {
+ s.cacheWriteChan = nil
+ }
+ s.cacheWriteMu.Unlock()
})
}
func (s *BillingCacheService) startCacheWriteWorkers() {
- s.cacheWriteChan = make(chan cacheWriteTask, cacheWriteBufferSize)
+ ch := make(chan cacheWriteTask, cacheWriteBufferSize)
+ s.cacheWriteChan = ch
for i := 0; i < cacheWriteWorkerCount; i++ {
s.cacheWriteWg.Add(1)
- go s.cacheWriteWorker()
+ go s.cacheWriteWorker(ch)
}
}
// enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) {
- if s.cacheWriteChan == nil {
+ if s.stopped.Load() {
+ s.logCacheWriteDrop(task, "closed")
return false
}
- defer func() {
- if recovered := recover(); recovered != nil {
- // 队列已关闭时可能触发 panic,记录后静默失败。
- s.logCacheWriteDrop(task, "closed")
- enqueued = false
- }
- }()
+
+ s.cacheWriteMu.RLock()
+ defer s.cacheWriteMu.RUnlock()
+
+ if s.cacheWriteChan == nil {
+ s.logCacheWriteDrop(task, "closed")
+ return false
+ }
+
select {
case s.cacheWriteChan <- task:
return true
@@ -144,9 +167,9 @@ func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued b
}
}
-func (s *BillingCacheService) cacheWriteWorker() {
+func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) {
defer s.cacheWriteWg.Done()
- for task := range s.cacheWriteChan {
+ for task := range ch {
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
switch task.kind {
case cacheWriteSetBalance:
@@ -243,19 +266,31 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64)
return balance, nil
}
- // 缓存未命中,从数据库读取
- balance, err = s.getUserBalanceFromDB(ctx, userID)
+ // 缓存未命中:singleflight 合并同一 userID 的并发回源请求。
+ value, err, _ := s.balanceLoadSF.Do(strconv.FormatInt(userID, 10), func() (any, error) {
+ loadCtx, cancel := context.WithTimeout(context.Background(), balanceLoadTimeout)
+ defer cancel()
+
+ balance, err := s.getUserBalanceFromDB(loadCtx, userID)
+ if err != nil {
+ return nil, err
+ }
+
+ // 异步建立缓存
+ _ = s.enqueueCacheWrite(cacheWriteTask{
+ kind: cacheWriteSetBalance,
+ userID: userID,
+ balance: balance,
+ })
+ return balance, nil
+ })
if err != nil {
return 0, err
}
-
- // 异步建立缓存
- _ = s.enqueueCacheWrite(cacheWriteTask{
- kind: cacheWriteSetBalance,
- userID: userID,
- balance: balance,
- })
-
+ balance, ok := value.(float64)
+ if !ok {
+ return 0, fmt.Errorf("unexpected balance type: %T", value)
+ }
return balance, nil
}
diff --git a/backend/internal/service/billing_cache_service_singleflight_test.go b/backend/internal/service/billing_cache_service_singleflight_test.go
new file mode 100644
index 00000000..1b12c402
--- /dev/null
+++ b/backend/internal/service/billing_cache_service_singleflight_test.go
@@ -0,0 +1,115 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "errors"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+type billingCacheMissStub struct {
+ setBalanceCalls atomic.Int64
+}
+
+func (s *billingCacheMissStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
+ return 0, errors.New("cache miss")
+}
+
+func (s *billingCacheMissStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
+ s.setBalanceCalls.Add(1)
+ return nil
+}
+
+func (s *billingCacheMissStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
+ return nil
+}
+
+func (s *billingCacheMissStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
+ return nil
+}
+
+func (s *billingCacheMissStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
+ return nil, errors.New("cache miss")
+}
+
+func (s *billingCacheMissStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
+ return nil
+}
+
+func (s *billingCacheMissStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
+ return nil
+}
+
+func (s *billingCacheMissStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
+ return nil
+}
+
+type balanceLoadUserRepoStub struct {
+ mockUserRepo
+ calls atomic.Int64
+ delay time.Duration
+ balance float64
+}
+
+func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
+ s.calls.Add(1)
+ if s.delay > 0 {
+ select {
+ case <-time.After(s.delay):
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+ }
+ return &User{ID: id, Balance: s.balance}, nil
+}
+
+func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
+ cache := &billingCacheMissStub{}
+ userRepo := &balanceLoadUserRepoStub{
+ delay: 80 * time.Millisecond,
+ balance: 12.34,
+ }
+ svc := NewBillingCacheService(cache, userRepo, nil, &config.Config{})
+ t.Cleanup(svc.Stop)
+
+ const goroutines = 16
+ start := make(chan struct{})
+ var wg sync.WaitGroup
+ errCh := make(chan error, goroutines)
+ balCh := make(chan float64, goroutines)
+
+ for i := 0; i < goroutines; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ <-start
+ bal, err := svc.GetUserBalance(context.Background(), 99)
+ errCh <- err
+ balCh <- bal
+ }()
+ }
+
+ close(start)
+ wg.Wait()
+ close(errCh)
+ close(balCh)
+
+ for err := range errCh {
+ require.NoError(t, err)
+ }
+ for bal := range balCh {
+ require.Equal(t, 12.34, bal)
+ }
+
+ require.Equal(t, int64(1), userRepo.calls.Load(), "并发穿透应被 singleflight 合并")
+ require.Eventually(t, func() bool {
+ return cache.setBalanceCalls.Load() >= 1
+ }, time.Second, 10*time.Millisecond)
+}
diff --git a/backend/internal/service/billing_cache_service_test.go b/backend/internal/service/billing_cache_service_test.go
index 445d5319..4e5f50e2 100644
--- a/backend/internal/service/billing_cache_service_test.go
+++ b/backend/internal/service/billing_cache_service_test.go
@@ -73,3 +73,16 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
return atomic.LoadInt64(&cache.subscriptionUpdates) > 0
}, 2*time.Second, 10*time.Millisecond)
}
+
+func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
+ cache := &billingCacheWorkerStub{}
+ svc := NewBillingCacheService(cache, nil, nil, &config.Config{})
+ svc.Stop()
+
+ enqueued := svc.enqueueCacheWrite(cacheWriteTask{
+ kind: cacheWriteDeductBalance,
+ userID: 1,
+ amount: 1,
+ })
+ require.False(t, enqueued)
+}
diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go
index a523001c..6abd1e53 100644
--- a/backend/internal/service/billing_service.go
+++ b/backend/internal/service/billing_service.go
@@ -543,7 +543,10 @@ func (s *BillingService) getDefaultImagePrice(model string, imageSize string) fl
basePrice = 0.134
}
- // 4K 尺寸翻倍
+ // 2K 尺寸 1.5 倍,4K 尺寸翻倍
+ if imageSize == "2K" {
+ return basePrice * 1.5
+ }
if imageSize == "4K" {
return basePrice * 2
}
diff --git a/backend/internal/service/billing_service_image_test.go b/backend/internal/service/billing_service_image_test.go
index 18a6b74d..fa90f6bb 100644
--- a/backend/internal/service/billing_service_image_test.go
+++ b/backend/internal/service/billing_service_image_test.go
@@ -12,14 +12,14 @@ import (
func TestCalculateImageCost_DefaultPricing(t *testing.T) {
svc := &BillingService{} // pricingService 为 nil,使用硬编码默认值
- // 2K 尺寸,默认价格 $0.134
+ // 2K 尺寸,默认价格 $0.134 * 1.5 = $0.201
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.0)
- require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
- require.InDelta(t, 0.134, cost.ActualCost, 0.0001)
+ require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
+ require.InDelta(t, 0.201, cost.ActualCost, 0.0001)
// 多张图片
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 3, nil, 1.0)
- require.InDelta(t, 0.402, cost.TotalCost, 0.0001)
+ require.InDelta(t, 0.603, cost.TotalCost, 0.0001)
}
// TestCalculateImageCost_GroupCustomPricing 测试分组自定义价格
@@ -63,13 +63,13 @@ func TestCalculateImageCost_RateMultiplier(t *testing.T) {
// 费率倍数 1.5x
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.5)
- require.InDelta(t, 0.134, cost.TotalCost, 0.0001) // TotalCost 不变
- require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // ActualCost = 0.134 * 1.5
+ require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // TotalCost = 0.134 * 1.5
+ require.InDelta(t, 0.3015, cost.ActualCost, 0.0001) // ActualCost = 0.201 * 1.5
// 费率倍数 2.0x
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 2, nil, 2.0)
- require.InDelta(t, 0.268, cost.TotalCost, 0.0001)
- require.InDelta(t, 0.536, cost.ActualCost, 0.0001)
+ require.InDelta(t, 0.402, cost.TotalCost, 0.0001)
+ require.InDelta(t, 0.804, cost.ActualCost, 0.0001)
}
// TestCalculateImageCost_ZeroCount 测试 imageCount=0
@@ -95,8 +95,8 @@ func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) {
svc := &BillingService{}
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0)
- require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
- require.InDelta(t, 0.134, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理
+ require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
+ require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理
}
// TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格
@@ -127,9 +127,9 @@ func TestGetImageUnitPrice_PartialGroupConfig(t *testing.T) {
cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 1, groupConfig, 1.0)
require.InDelta(t, 0.10, cost.TotalCost, 0.0001)
- // 2K 回退默认价格 $0.134
+ // 2K 回退默认价格 $0.201 (1.5倍)
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, groupConfig, 1.0)
- require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
+ require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
// 4K 回退默认价格 $0.268 (翻倍)
cost = svc.CalculateImageCost("gemini-3-pro-image", "4K", 1, groupConfig, 1.0)
@@ -140,10 +140,10 @@ func TestGetImageUnitPrice_PartialGroupConfig(t *testing.T) {
func TestGetDefaultImagePrice_FallbackHardcoded(t *testing.T) {
svc := &BillingService{} // pricingService 为 nil
- // 1K 和 2K 使用相同的默认价格 $0.134
+ // 1K 默认价格 $0.134,2K 默认价格 $0.201 (1.5倍)
cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 1, nil, 1.0)
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.0)
- require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
+ require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
}
diff --git a/backend/internal/service/claude_code_validator.go b/backend/internal/service/claude_code_validator.go
index 6d06c83e..f71098b1 100644
--- a/backend/internal/service/claude_code_validator.go
+++ b/backend/internal/service/claude_code_validator.go
@@ -4,6 +4,7 @@ import (
"context"
"net/http"
"regexp"
+ "strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
@@ -17,6 +18,9 @@ var (
// User-Agent 匹配: claude-cli/x.x.x (仅支持官方 CLI,大小写不敏感)
claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
+ // 带捕获组的版本提取正则
+ claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`)
+
// metadata.user_id 格式: user_{64位hex}_account__session_{uuid}
userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`)
@@ -78,7 +82,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过
// 这类请求用于 Claude Code 验证 API 连通性,不携带 system prompt
- if isMaxTokensOneHaiku, ok := r.Context().Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok && isMaxTokensOneHaiku {
+ if isMaxTokensOneHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(r.Context()); ok && isMaxTokensOneHaiku {
return true // 绕过 system prompt 检查,UA 已在 Step 1 验证
}
@@ -270,3 +274,55 @@ func IsClaudeCodeClient(ctx context.Context) bool {
func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context {
return context.WithValue(ctx, ctxkey.IsClaudeCodeClient, isClaudeCode)
}
+
+// ExtractVersion 从 User-Agent 中提取 Claude Code 版本号
+// 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串
+func (v *ClaudeCodeValidator) ExtractVersion(ua string) string {
+ matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua)
+ if len(matches) >= 2 {
+ return matches[1]
+ }
+ return ""
+}
+
+// SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中
+func SetClaudeCodeVersion(ctx context.Context, version string) context.Context {
+ return context.WithValue(ctx, ctxkey.ClaudeCodeVersion, version)
+}
+
+// GetClaudeCodeVersion 从 context 中获取 Claude Code 版本号
+func GetClaudeCodeVersion(ctx context.Context) string {
+ if v, ok := ctx.Value(ctxkey.ClaudeCodeVersion).(string); ok {
+ return v
+ }
+ return ""
+}
+
+// CompareVersions 比较两个 semver 版本号
+// 返回: -1 (a < b), 0 (a == b), 1 (a > b)
+func CompareVersions(a, b string) int {
+ aParts := parseSemver(a)
+ bParts := parseSemver(b)
+ for i := 0; i < 3; i++ {
+ if aParts[i] < bParts[i] {
+ return -1
+ }
+ if aParts[i] > bParts[i] {
+ return 1
+ }
+ }
+ return 0
+}
+
+// parseSemver 解析 semver 版本号为 [major, minor, patch]
+func parseSemver(v string) [3]int {
+ v = strings.TrimPrefix(v, "v")
+ parts := strings.Split(v, ".")
+ result := [3]int{0, 0, 0}
+ for i := 0; i < len(parts) && i < 3; i++ {
+ if parsed, err := strconv.Atoi(parts[i]); err == nil {
+ result[i] = parsed
+ }
+ }
+ return result
+}
diff --git a/backend/internal/service/claude_code_validator_test.go b/backend/internal/service/claude_code_validator_test.go
index a4cd1886..f87c56e8 100644
--- a/backend/internal/service/claude_code_validator_test.go
+++ b/backend/internal/service/claude_code_validator_test.go
@@ -56,3 +56,51 @@ func TestClaudeCodeValidator_NonMessagesPathUAOnly(t *testing.T) {
ok := validator.Validate(req, nil)
require.True(t, ok)
}
+
+func TestExtractVersion(t *testing.T) {
+ v := NewClaudeCodeValidator()
+ tests := []struct {
+ ua string
+ want string
+ }{
+ {"claude-cli/2.1.22 (darwin; arm64)", "2.1.22"},
+ {"claude-cli/1.0.0", "1.0.0"},
+ {"Claude-CLI/3.10.5 (linux; x86_64)", "3.10.5"}, // 大小写不敏感
+ {"curl/8.0.0", ""}, // 非 Claude CLI
+ {"", ""}, // 空字符串
+ {"claude-cli/", ""}, // 无版本号
+ {"claude-cli/2.1.22-beta", "2.1.22"}, // 带后缀仍提取主版本号
+ }
+ for _, tt := range tests {
+ got := v.ExtractVersion(tt.ua)
+ require.Equal(t, tt.want, got, "ExtractVersion(%q)", tt.ua)
+ }
+}
+
+func TestCompareVersions(t *testing.T) {
+ tests := []struct {
+ a, b string
+ want int
+ }{
+ {"2.1.0", "2.1.0", 0}, // 相等
+ {"2.1.1", "2.1.0", 1}, // patch 更大
+ {"2.0.0", "2.1.0", -1}, // minor 更小
+ {"3.0.0", "2.99.99", 1}, // major 更大
+ {"1.0.0", "2.0.0", -1}, // major 更小
+ {"0.0.1", "0.0.0", 1}, // patch 差异
+ {"", "1.0.0", -1}, // 空字符串 vs 正常版本
+ {"v2.1.0", "2.1.0", 0}, // v 前缀处理
+ }
+ for _, tt := range tests {
+ got := CompareVersions(tt.a, tt.b)
+ require.Equal(t, tt.want, got, "CompareVersions(%q, %q)", tt.a, tt.b)
+ }
+}
+
+func TestSetGetClaudeCodeVersion(t *testing.T) {
+ ctx := context.Background()
+ require.Equal(t, "", GetClaudeCodeVersion(ctx), "empty context should return empty string")
+
+ ctx = SetClaudeCodeVersion(ctx, "2.1.63")
+ require.Equal(t, "2.1.63", GetClaudeCodeVersion(ctx))
+}
diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go
index 32b6d97c..4dcf84e0 100644
--- a/backend/internal/service/concurrency_service.go
+++ b/backend/internal/service/concurrency_service.go
@@ -3,8 +3,10 @@ package service
import (
"context"
"crypto/rand"
- "encoding/hex"
- "fmt"
+ "encoding/binary"
+ "os"
+ "strconv"
+ "sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -18,6 +20,7 @@ type ConcurrencyCache interface {
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
+ GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
// 账号等待队列(账号级)
IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
@@ -42,15 +45,25 @@ type ConcurrencyCache interface {
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
}
-// generateRequestID generates a unique request ID for concurrency slot tracking
-// Uses 8 random bytes (16 hex chars) for uniqueness
-func generateRequestID() string {
+var (
+ requestIDPrefix = initRequestIDPrefix()
+ requestIDCounter atomic.Uint64
+)
+
+func initRequestIDPrefix() string {
b := make([]byte, 8)
- if _, err := rand.Read(b); err != nil {
- // Fallback to nanosecond timestamp (extremely rare case)
- return fmt.Sprintf("%x", time.Now().UnixNano())
+ if _, err := rand.Read(b); err == nil {
+ return "r" + strconv.FormatUint(binary.BigEndian.Uint64(b), 36)
}
- return hex.EncodeToString(b)
+ fallback := uint64(time.Now().UnixNano()) ^ (uint64(os.Getpid()) << 16)
+ return "r" + strconv.FormatUint(fallback, 36)
+}
+
+// generateRequestID generates a unique request ID for concurrency slot tracking.
+// Format: {process_random_prefix}-{base36_counter}
+func generateRequestID() string {
+ seq := requestIDCounter.Add(1)
+ return requestIDPrefix + "-" + strconv.FormatUint(seq, 36)
}
const (
@@ -321,16 +334,15 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
- result := make(map[int64]int)
-
- for _, accountID := range accountIDs {
- count, err := s.cache.GetAccountConcurrency(ctx, accountID)
- if err != nil {
- // If key doesn't exist in Redis, count is 0
- count = 0
- }
- result[accountID] = count
+ if len(accountIDs) == 0 {
+ return map[int64]int{}, nil
}
-
- return result, nil
+ if s.cache == nil {
+ result := make(map[int64]int, len(accountIDs))
+ for _, accountID := range accountIDs {
+ result[accountID] = 0
+ }
+ return result, nil
+ }
+ return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs)
}
diff --git a/backend/internal/service/concurrency_service_test.go b/backend/internal/service/concurrency_service_test.go
index 33ce4cb9..9ba43d93 100644
--- a/backend/internal/service/concurrency_service_test.go
+++ b/backend/internal/service/concurrency_service_test.go
@@ -5,6 +5,8 @@ package service
import (
"context"
"errors"
+ "strconv"
+ "strings"
"testing"
"github.com/stretchr/testify/require"
@@ -12,20 +14,20 @@ import (
// stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩
type stubConcurrencyCacheForTest struct {
- acquireResult bool
- acquireErr error
- releaseErr error
- concurrency int
+ acquireResult bool
+ acquireErr error
+ releaseErr error
+ concurrency int
concurrencyErr error
- waitAllowed bool
- waitErr error
- waitCount int
- waitCountErr error
- loadBatch map[int64]*AccountLoadInfo
- loadBatchErr error
+ waitAllowed bool
+ waitErr error
+ waitCount int
+ waitCountErr error
+ loadBatch map[int64]*AccountLoadInfo
+ loadBatchErr error
usersLoadBatch map[int64]*UserLoadInfo
usersLoadErr error
- cleanupErr error
+ cleanupErr error
// 记录调用
releasedAccountIDs []int64
@@ -45,6 +47,16 @@ func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, acco
func (c *stubConcurrencyCacheForTest) GetAccountConcurrency(_ context.Context, _ int64) (int, error) {
return c.concurrency, c.concurrencyErr
}
+func (c *stubConcurrencyCacheForTest) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) {
+ result := make(map[int64]int, len(accountIDs))
+ for _, accountID := range accountIDs {
+ if c.concurrencyErr != nil {
+ return nil, c.concurrencyErr
+ }
+ result[accountID] = c.concurrency
+ }
+ return result, nil
+}
func (c *stubConcurrencyCacheForTest) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) {
return c.waitAllowed, c.waitErr
}
@@ -155,6 +167,25 @@ func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) {
require.True(t, result.Acquired)
}
+func TestGenerateRequestID_UsesStablePrefixAndMonotonicCounter(t *testing.T) {
+ id1 := generateRequestID()
+ id2 := generateRequestID()
+ require.NotEmpty(t, id1)
+ require.NotEmpty(t, id2)
+
+ p1 := strings.Split(id1, "-")
+ p2 := strings.Split(id2, "-")
+ require.Len(t, p1, 2)
+ require.Len(t, p2, 2)
+ require.Equal(t, p1[0], p2[0], "同一进程前缀应保持一致")
+
+ n1, err := strconv.ParseUint(p1[1], 36, 64)
+ require.NoError(t, err)
+ n2, err := strconv.ParseUint(p2[1], 36, 64)
+ require.NoError(t, err)
+ require.Equal(t, n1+1, n2, "计数器应单调递增")
+}
+
func TestGetAccountsLoadBatch_ReturnsCorrectData(t *testing.T) {
expected := map[int64]*AccountLoadInfo{
1: {AccountID: 1, CurrentConcurrency: 3, WaitingCount: 0, LoadRate: 60},
diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go
index 9aab10d2..2af43386 100644
--- a/backend/internal/service/dashboard_service.go
+++ b/backend/internal/service/dashboard_service.go
@@ -124,22 +124,30 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
return stats, nil
}
-func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
- trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
+func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
+ trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
if err != nil {
return nil, fmt.Errorf("get usage trend with filters: %w", err)
}
return trend, nil
}
-func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
- stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
+func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
+ stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
return nil, fmt.Errorf("get model stats with filters: %w", err)
}
return stats, nil
}
+func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
+ stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
+ if err != nil {
+ return nil, fmt.Errorf("get group stats with filters: %w", err)
+ }
+ return stats, nil
+}
+
func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
data, err := s.cache.GetDashboardStats(ctx)
if err != nil {
diff --git a/backend/internal/service/data_management_grpc.go b/backend/internal/service/data_management_grpc.go
new file mode 100644
index 00000000..aeb3d529
--- /dev/null
+++ b/backend/internal/service/data_management_grpc.go
@@ -0,0 +1,252 @@
+package service
+
+import "context"
+
+type DataManagementPostgresConfig struct {
+ Host string `json:"host"`
+ Port int32 `json:"port"`
+ User string `json:"user"`
+ Password string `json:"password,omitempty"`
+ PasswordConfigured bool `json:"password_configured"`
+ Database string `json:"database"`
+ SSLMode string `json:"ssl_mode"`
+ ContainerName string `json:"container_name"`
+}
+
+type DataManagementRedisConfig struct {
+ Addr string `json:"addr"`
+ Username string `json:"username"`
+ Password string `json:"password,omitempty"`
+ PasswordConfigured bool `json:"password_configured"`
+ DB int32 `json:"db"`
+ ContainerName string `json:"container_name"`
+}
+
+type DataManagementS3Config struct {
+ Enabled bool `json:"enabled"`
+ Endpoint string `json:"endpoint"`
+ Region string `json:"region"`
+ Bucket string `json:"bucket"`
+ AccessKeyID string `json:"access_key_id"`
+ SecretAccessKey string `json:"secret_access_key,omitempty"`
+ SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
+ Prefix string `json:"prefix"`
+ ForcePathStyle bool `json:"force_path_style"`
+ UseSSL bool `json:"use_ssl"`
+}
+
+type DataManagementConfig struct {
+ SourceMode string `json:"source_mode"`
+ BackupRoot string `json:"backup_root"`
+ SQLitePath string `json:"sqlite_path,omitempty"`
+ RetentionDays int32 `json:"retention_days"`
+ KeepLast int32 `json:"keep_last"`
+ ActivePostgresID string `json:"active_postgres_profile_id"`
+ ActiveRedisID string `json:"active_redis_profile_id"`
+ Postgres DataManagementPostgresConfig `json:"postgres"`
+ Redis DataManagementRedisConfig `json:"redis"`
+ S3 DataManagementS3Config `json:"s3"`
+ ActiveS3ProfileID string `json:"active_s3_profile_id"`
+}
+
+type DataManagementTestS3Result struct {
+ OK bool `json:"ok"`
+ Message string `json:"message"`
+}
+
+type DataManagementCreateBackupJobInput struct {
+ BackupType string
+ UploadToS3 bool
+ TriggeredBy string
+ IdempotencyKey string
+ S3ProfileID string
+ PostgresID string
+ RedisID string
+}
+
+type DataManagementListBackupJobsInput struct {
+ PageSize int32
+ PageToken string
+ Status string
+ BackupType string
+}
+
+type DataManagementArtifactInfo struct {
+ LocalPath string `json:"local_path"`
+ SizeBytes int64 `json:"size_bytes"`
+ SHA256 string `json:"sha256"`
+}
+
+type DataManagementS3ObjectInfo struct {
+ Bucket string `json:"bucket"`
+ Key string `json:"key"`
+ ETag string `json:"etag"`
+}
+
+type DataManagementBackupJob struct {
+ JobID string `json:"job_id"`
+ BackupType string `json:"backup_type"`
+ Status string `json:"status"`
+ TriggeredBy string `json:"triggered_by"`
+ IdempotencyKey string `json:"idempotency_key,omitempty"`
+ UploadToS3 bool `json:"upload_to_s3"`
+ S3ProfileID string `json:"s3_profile_id,omitempty"`
+ PostgresID string `json:"postgres_profile_id,omitempty"`
+ RedisID string `json:"redis_profile_id,omitempty"`
+ StartedAt string `json:"started_at,omitempty"`
+ FinishedAt string `json:"finished_at,omitempty"`
+ ErrorMessage string `json:"error_message,omitempty"`
+ Artifact DataManagementArtifactInfo `json:"artifact"`
+ S3Object DataManagementS3ObjectInfo `json:"s3"`
+}
+
+type DataManagementSourceProfile struct {
+ SourceType string `json:"source_type"`
+ ProfileID string `json:"profile_id"`
+ Name string `json:"name"`
+ IsActive bool `json:"is_active"`
+ Config DataManagementSourceConfig `json:"config"`
+ PasswordConfigured bool `json:"password_configured"`
+ CreatedAt string `json:"created_at,omitempty"`
+ UpdatedAt string `json:"updated_at,omitempty"`
+}
+
+type DataManagementSourceConfig struct {
+ Host string `json:"host"`
+ Port int32 `json:"port"`
+ User string `json:"user"`
+ Password string `json:"password,omitempty"`
+ Database string `json:"database"`
+ SSLMode string `json:"ssl_mode"`
+ Addr string `json:"addr"`
+ Username string `json:"username"`
+ DB int32 `json:"db"`
+ ContainerName string `json:"container_name"`
+}
+
+type DataManagementCreateSourceProfileInput struct {
+ SourceType string
+ ProfileID string
+ Name string
+ Config DataManagementSourceConfig
+ SetActive bool
+}
+
+type DataManagementUpdateSourceProfileInput struct {
+ SourceType string
+ ProfileID string
+ Name string
+ Config DataManagementSourceConfig
+}
+
+type DataManagementS3Profile struct {
+ ProfileID string `json:"profile_id"`
+ Name string `json:"name"`
+ IsActive bool `json:"is_active"`
+ S3 DataManagementS3Config `json:"s3"`
+ SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
+ CreatedAt string `json:"created_at,omitempty"`
+ UpdatedAt string `json:"updated_at,omitempty"`
+}
+
+type DataManagementCreateS3ProfileInput struct {
+ ProfileID string
+ Name string
+ S3 DataManagementS3Config
+ SetActive bool
+}
+
+type DataManagementUpdateS3ProfileInput struct {
+ ProfileID string
+ Name string
+ S3 DataManagementS3Config
+}
+
+type DataManagementListBackupJobsResult struct {
+ Items []DataManagementBackupJob `json:"items"`
+ NextPageToken string `json:"next_page_token,omitempty"`
+}
+
+func (s *DataManagementService) GetConfig(ctx context.Context) (DataManagementConfig, error) {
+ _ = ctx
+ return DataManagementConfig{}, s.deprecatedError()
+}
+
+func (s *DataManagementService) UpdateConfig(ctx context.Context, cfg DataManagementConfig) (DataManagementConfig, error) {
+ _, _ = ctx, cfg
+ return DataManagementConfig{}, s.deprecatedError()
+}
+
+func (s *DataManagementService) ListSourceProfiles(ctx context.Context, sourceType string) ([]DataManagementSourceProfile, error) {
+ _, _ = ctx, sourceType
+ return nil, s.deprecatedError()
+}
+
+func (s *DataManagementService) CreateSourceProfile(ctx context.Context, input DataManagementCreateSourceProfileInput) (DataManagementSourceProfile, error) {
+ _, _ = ctx, input
+ return DataManagementSourceProfile{}, s.deprecatedError()
+}
+
+func (s *DataManagementService) UpdateSourceProfile(ctx context.Context, input DataManagementUpdateSourceProfileInput) (DataManagementSourceProfile, error) {
+ _, _ = ctx, input
+ return DataManagementSourceProfile{}, s.deprecatedError()
+}
+
+func (s *DataManagementService) DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error {
+ _, _, _ = ctx, sourceType, profileID
+ return s.deprecatedError()
+}
+
+func (s *DataManagementService) SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (DataManagementSourceProfile, error) {
+ _, _, _ = ctx, sourceType, profileID
+ return DataManagementSourceProfile{}, s.deprecatedError()
+}
+
+func (s *DataManagementService) ValidateS3(ctx context.Context, cfg DataManagementS3Config) (DataManagementTestS3Result, error) {
+ _, _ = ctx, cfg
+ return DataManagementTestS3Result{}, s.deprecatedError()
+}
+
+func (s *DataManagementService) ListS3Profiles(ctx context.Context) ([]DataManagementS3Profile, error) {
+ _ = ctx
+ return nil, s.deprecatedError()
+}
+
+func (s *DataManagementService) CreateS3Profile(ctx context.Context, input DataManagementCreateS3ProfileInput) (DataManagementS3Profile, error) {
+ _, _ = ctx, input
+ return DataManagementS3Profile{}, s.deprecatedError()
+}
+
+func (s *DataManagementService) UpdateS3Profile(ctx context.Context, input DataManagementUpdateS3ProfileInput) (DataManagementS3Profile, error) {
+ _, _ = ctx, input
+ return DataManagementS3Profile{}, s.deprecatedError()
+}
+
+func (s *DataManagementService) DeleteS3Profile(ctx context.Context, profileID string) error {
+ _, _ = ctx, profileID
+ return s.deprecatedError()
+}
+
+func (s *DataManagementService) SetActiveS3Profile(ctx context.Context, profileID string) (DataManagementS3Profile, error) {
+ _, _ = ctx, profileID
+ return DataManagementS3Profile{}, s.deprecatedError()
+}
+
+func (s *DataManagementService) CreateBackupJob(ctx context.Context, input DataManagementCreateBackupJobInput) (DataManagementBackupJob, error) {
+ _, _ = ctx, input
+ return DataManagementBackupJob{}, s.deprecatedError()
+}
+
+func (s *DataManagementService) ListBackupJobs(ctx context.Context, input DataManagementListBackupJobsInput) (DataManagementListBackupJobsResult, error) {
+ _, _ = ctx, input
+ return DataManagementListBackupJobsResult{}, s.deprecatedError()
+}
+
+func (s *DataManagementService) GetBackupJob(ctx context.Context, jobID string) (DataManagementBackupJob, error) {
+ _, _ = ctx, jobID
+ return DataManagementBackupJob{}, s.deprecatedError()
+}
+
+func (s *DataManagementService) deprecatedError() error {
+ return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()})
+}
diff --git a/backend/internal/service/data_management_grpc_test.go b/backend/internal/service/data_management_grpc_test.go
new file mode 100644
index 00000000..286eb58d
--- /dev/null
+++ b/backend/internal/service/data_management_grpc_test.go
@@ -0,0 +1,36 @@
+package service
+
+import (
+ "context"
+ "path/filepath"
+ "testing"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/stretchr/testify/require"
+)
+
+func TestDataManagementService_DeprecatedRPCMethods(t *testing.T) {
+ t.Parallel()
+
+ socketPath := filepath.Join(t.TempDir(), "datamanagement.sock")
+ svc := NewDataManagementServiceWithOptions(socketPath, 0)
+
+ _, err := svc.GetConfig(context.Background())
+ assertDeprecatedDataManagementError(t, err, socketPath)
+
+ _, err = svc.CreateBackupJob(context.Background(), DataManagementCreateBackupJobInput{BackupType: "full"})
+ assertDeprecatedDataManagementError(t, err, socketPath)
+
+ err = svc.DeleteS3Profile(context.Background(), "s3-default")
+ assertDeprecatedDataManagementError(t, err, socketPath)
+}
+
+func assertDeprecatedDataManagementError(t *testing.T, err error, socketPath string) {
+ t.Helper()
+
+ require.Error(t, err)
+ statusCode, status := infraerrors.ToHTTP(err)
+ require.Equal(t, 503, statusCode)
+ require.Equal(t, DataManagementDeprecatedReason, status.Reason)
+ require.Equal(t, socketPath, status.Metadata["socket_path"])
+}
diff --git a/backend/internal/service/data_management_service.go b/backend/internal/service/data_management_service.go
new file mode 100644
index 00000000..b525c0fa
--- /dev/null
+++ b/backend/internal/service/data_management_service.go
@@ -0,0 +1,95 @@
+package service
+
+import (
+ "context"
+ "strings"
+ "time"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+const (
+ DefaultDataManagementAgentSocketPath = "/tmp/sub2api-datamanagement.sock"
+ LegacyBackupAgentSocketPath = "/tmp/sub2api-backup.sock"
+
+ DataManagementDeprecatedReason = "DATA_MANAGEMENT_DEPRECATED"
+ DataManagementAgentSocketMissingReason = "DATA_MANAGEMENT_AGENT_SOCKET_MISSING"
+ DataManagementAgentUnavailableReason = "DATA_MANAGEMENT_AGENT_UNAVAILABLE"
+
+ // Deprecated: keep old names for compatibility.
+ DefaultBackupAgentSocketPath = DefaultDataManagementAgentSocketPath
+ BackupAgentSocketMissingReason = DataManagementAgentSocketMissingReason
+ BackupAgentUnavailableReason = DataManagementAgentUnavailableReason
+)
+
+var (
+ ErrDataManagementDeprecated = infraerrors.ServiceUnavailable(
+ DataManagementDeprecatedReason,
+ "data management feature is deprecated",
+ )
+ ErrDataManagementAgentSocketMissing = infraerrors.ServiceUnavailable(
+ DataManagementAgentSocketMissingReason,
+ "data management agent socket is missing",
+ )
+ ErrDataManagementAgentUnavailable = infraerrors.ServiceUnavailable(
+ DataManagementAgentUnavailableReason,
+ "data management agent is unavailable",
+ )
+
+ // Deprecated: keep old names for compatibility.
+ ErrBackupAgentSocketMissing = ErrDataManagementAgentSocketMissing
+ ErrBackupAgentUnavailable = ErrDataManagementAgentUnavailable
+)
+
+type DataManagementAgentHealth struct {
+ Enabled bool
+ Reason string
+ SocketPath string
+ Agent *DataManagementAgentInfo
+}
+
+type DataManagementAgentInfo struct {
+ Status string
+ Version string
+ UptimeSeconds int64
+}
+
+type DataManagementService struct {
+ socketPath string
+}
+
+func NewDataManagementService() *DataManagementService {
+ return NewDataManagementServiceWithOptions(DefaultDataManagementAgentSocketPath, 500*time.Millisecond)
+}
+
+func NewDataManagementServiceWithOptions(socketPath string, dialTimeout time.Duration) *DataManagementService {
+ _ = dialTimeout
+ path := strings.TrimSpace(socketPath)
+ if path == "" {
+ path = DefaultDataManagementAgentSocketPath
+ }
+ return &DataManagementService{
+ socketPath: path,
+ }
+}
+
+func (s *DataManagementService) SocketPath() string {
+ if s == nil || strings.TrimSpace(s.socketPath) == "" {
+ return DefaultDataManagementAgentSocketPath
+ }
+ return s.socketPath
+}
+
+func (s *DataManagementService) GetAgentHealth(ctx context.Context) DataManagementAgentHealth {
+ _ = ctx
+ return DataManagementAgentHealth{
+ Enabled: false,
+ Reason: DataManagementDeprecatedReason,
+ SocketPath: s.SocketPath(),
+ }
+}
+
+func (s *DataManagementService) EnsureAgentEnabled(ctx context.Context) error {
+ _ = ctx
+ return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()})
+}
diff --git a/backend/internal/service/data_management_service_test.go b/backend/internal/service/data_management_service_test.go
new file mode 100644
index 00000000..65489d2e
--- /dev/null
+++ b/backend/internal/service/data_management_service_test.go
@@ -0,0 +1,37 @@
+package service
+
+import (
+ "context"
+ "path/filepath"
+ "testing"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/stretchr/testify/require"
+)
+
+func TestDataManagementService_GetAgentHealth_Deprecated(t *testing.T) {
+ t.Parallel()
+
+ socketPath := filepath.Join(t.TempDir(), "unused.sock")
+ svc := NewDataManagementServiceWithOptions(socketPath, 0)
+ health := svc.GetAgentHealth(context.Background())
+
+ require.False(t, health.Enabled)
+ require.Equal(t, DataManagementDeprecatedReason, health.Reason)
+ require.Equal(t, socketPath, health.SocketPath)
+ require.Nil(t, health.Agent)
+}
+
+func TestDataManagementService_EnsureAgentEnabled_Deprecated(t *testing.T) {
+ t.Parallel()
+
+ socketPath := filepath.Join(t.TempDir(), "unused.sock")
+ svc := NewDataManagementServiceWithOptions(socketPath, 100)
+ err := svc.EnsureAgentEnabled(context.Background())
+ require.Error(t, err)
+
+ statusCode, status := infraerrors.ToHTTP(err)
+ require.Equal(t, 503, statusCode)
+ require.Equal(t, DataManagementDeprecatedReason, status.Reason)
+ require.Equal(t, socketPath, status.Metadata["socket_path"])
+}
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index ceae443f..b304bc9f 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -104,6 +104,7 @@ const (
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
// OEM设置
+ SettingKeySoraClientEnabled = "sora_client_enabled" // 是否启用 Sora 客户端(管理员手动控制)
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
@@ -116,8 +117,9 @@ const (
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // “购买订阅”页面 URL(作为 iframe src)
// 默认配置
- SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
- SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
+ SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
+ SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
+ SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
// 管理员 API Key
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
@@ -170,6 +172,34 @@ const (
// SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling.
SettingKeyStreamTimeoutSettings = "stream_timeout_settings"
+
+ // =========================
+ // Sora S3 存储配置
+ // =========================
+
+ SettingKeySoraS3Enabled = "sora_s3_enabled" // 是否启用 Sora S3 存储
+ SettingKeySoraS3Endpoint = "sora_s3_endpoint" // S3 端点地址
+ SettingKeySoraS3Region = "sora_s3_region" // S3 区域
+ SettingKeySoraS3Bucket = "sora_s3_bucket" // S3 存储桶名称
+ SettingKeySoraS3AccessKeyID = "sora_s3_access_key_id" // S3 Access Key ID
+ SettingKeySoraS3SecretAccessKey = "sora_s3_secret_access_key" // S3 Secret Access Key(加密存储)
+ SettingKeySoraS3Prefix = "sora_s3_prefix" // S3 对象键前缀
+ SettingKeySoraS3ForcePathStyle = "sora_s3_force_path_style" // 是否强制 Path Style(兼容 MinIO 等)
+ SettingKeySoraS3CDNURL = "sora_s3_cdn_url" // CDN 加速 URL(可选)
+ SettingKeySoraS3Profiles = "sora_s3_profiles" // Sora S3 多配置(JSON)
+
+ // =========================
+ // Sora 用户存储配额
+ // =========================
+
+ SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // 新用户默认 Sora 存储配额(字节)
+
+ // =========================
+ // Claude Code Version Check
+ // =========================
+
+ // SettingKeyMinClaudeCodeVersion 最低 Claude Code 版本号要求 (semver, 如 "2.1.0",空值=不检查)
+ SettingKeyMinClaudeCodeVersion = "min_claude_code_version"
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go
index 5183891b..f8c0ecda 100644
--- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go
+++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go
@@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"context"
+ "encoding/json"
"errors"
"io"
"net/http"
@@ -262,6 +263,107 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
require.Empty(t, rec.Header().Get("Set-Cookie"))
}
+func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotError(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ tests := []struct {
+ name string
+ statusCode int
+ respBody string
+ wantPassthrough bool
+ }{
+ {
+ name: "404 endpoint not found passes through as 404",
+ statusCode: http.StatusNotFound,
+ respBody: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"not_found_error"}}`,
+ wantPassthrough: true,
+ },
+ {
+ name: "404 generic not found does not passthrough",
+ statusCode: http.StatusNotFound,
+ respBody: `{"error":{"message":"resource not found","type":"not_found_error"}}`,
+ wantPassthrough: false,
+ },
+ {
+ name: "400 Invalid URL does not passthrough",
+ statusCode: http.StatusBadRequest,
+ respBody: `{"error":{"message":"Invalid URL (POST /v1/messages/count_tokens)","type":"invalid_request_error"}}`,
+ wantPassthrough: false,
+ },
+ {
+ name: "400 model error does not passthrough",
+ statusCode: http.StatusBadRequest,
+ respBody: `{"error":{"message":"model not found: claude-unknown","type":"invalid_request_error"}}`,
+ wantPassthrough: false,
+ },
+ {
+ name: "500 internal error does not passthrough",
+ statusCode: http.StatusInternalServerError,
+ respBody: `{"error":{"message":"internal error","type":"api_error"}}`,
+ wantPassthrough: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
+
+ body := []byte(`{"model":"claude-sonnet-4-5-20250929","messages":[{"role":"user","content":"hi"}]}`)
+ parsed := &ParsedRequest{Body: body, Model: "claude-sonnet-4-5-20250929"}
+
+ upstream := &anthropicHTTPUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: tt.statusCode,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(tt.respBody)),
+ },
+ }
+
+ svc := &GatewayService{
+ cfg: &config.Config{
+ Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
+ },
+ httpUpstream: upstream,
+ rateLimitService: nil,
+ }
+
+ account := &Account{
+ ID: 200,
+ Name: "proxy-acc",
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-proxy",
+ "base_url": "https://proxy.example.com",
+ },
+ Extra: map[string]any{"anthropic_passthrough": true},
+ Status: StatusActive,
+ Schedulable: true,
+ }
+
+ err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
+
+ if tt.wantPassthrough {
+ // 返回 nil(不记录为错误),HTTP 状态码 404 + Anthropic 错误体
+ require.NoError(t, err)
+ require.Equal(t, http.StatusNotFound, rec.Code)
+ var errResp map[string]any
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &errResp))
+ require.Equal(t, "error", errResp["type"])
+ errObj, ok := errResp["error"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "not_found_error", errObj["type"])
+ } else {
+ require.Error(t, err)
+ require.Equal(t, tt.statusCode, rec.Code)
+ }
+ })
+ }
+}
+
func TestGatewayService_AnthropicAPIKeyPassthrough_BuildRequestRejectsInvalidBaseURL(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
diff --git a/backend/internal/service/gateway_beta_test.go b/backend/internal/service/gateway_beta_test.go
index d7108c8d..21a1faa4 100644
--- a/backend/internal/service/gateway_beta_test.go
+++ b/backend/internal/service/gateway_beta_test.go
@@ -3,6 +3,8 @@ package service
import (
"testing"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
+
"github.com/stretchr/testify/require"
)
@@ -22,60 +24,78 @@ func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) {
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got)
}
-func TestStripBetaToken(t *testing.T) {
+func TestStripBetaTokens(t *testing.T) {
tests := []struct {
name string
header string
- token string
+ tokens []string
want string
}{
{
- name: "token in middle",
+ name: "single token in middle",
header: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14",
- token: "context-1m-2025-08-07",
+ tokens: []string{"context-1m-2025-08-07"},
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
},
{
- name: "token at start",
+ name: "single token at start",
header: "context-1m-2025-08-07,oauth-2025-04-20,interleaved-thinking-2025-05-14",
- token: "context-1m-2025-08-07",
+ tokens: []string{"context-1m-2025-08-07"},
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
},
{
- name: "token at end",
+ name: "single token at end",
header: "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07",
- token: "context-1m-2025-08-07",
+ tokens: []string{"context-1m-2025-08-07"},
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
},
{
name: "token not present",
header: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
- token: "context-1m-2025-08-07",
+ tokens: []string{"context-1m-2025-08-07"},
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
},
{
name: "empty header",
header: "",
- token: "context-1m-2025-08-07",
+ tokens: []string{"context-1m-2025-08-07"},
want: "",
},
{
name: "with spaces",
header: "oauth-2025-04-20, context-1m-2025-08-07 , interleaved-thinking-2025-05-14",
- token: "context-1m-2025-08-07",
+ tokens: []string{"context-1m-2025-08-07"},
want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
},
{
name: "only token",
header: "context-1m-2025-08-07",
- token: "context-1m-2025-08-07",
+ tokens: []string{"context-1m-2025-08-07"},
want: "",
},
+ {
+ name: "nil tokens",
+ header: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
+ tokens: nil,
+ want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
+ },
+ {
+ name: "multiple tokens removed",
+ header: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14,fast-mode-2026-02-01",
+ tokens: []string{"context-1m-2025-08-07", "fast-mode-2026-02-01"},
+ want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
+ },
+ {
+ name: "DroppedBetas removes both context-1m and fast-mode",
+ header: "oauth-2025-04-20,context-1m-2025-08-07,fast-mode-2026-02-01,interleaved-thinking-2025-05-14",
+ tokens: claude.DroppedBetas,
+ want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
+ },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- got := stripBetaToken(tt.header, tt.token)
+ got := stripBetaTokens(tt.header, tt.tokens)
require.Equal(t, tt.want, got)
})
}
@@ -90,3 +110,93 @@ func TestMergeAnthropicBetaDropping_Context1M(t *testing.T) {
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", got)
require.NotContains(t, got, "context-1m-2025-08-07")
}
+
+func TestMergeAnthropicBetaDropping_DroppedBetas(t *testing.T) {
+ required := []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"}
+ incoming := "context-1m-2025-08-07,fast-mode-2026-02-01,foo-beta,oauth-2025-04-20"
+ drop := droppedBetaSet()
+
+ got := mergeAnthropicBetaDropping(required, incoming, drop)
+ require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", got)
+ require.NotContains(t, got, "context-1m-2025-08-07")
+ require.NotContains(t, got, "fast-mode-2026-02-01")
+}
+
+func TestDroppedBetaSet(t *testing.T) {
+ // Base set contains DroppedBetas
+ base := droppedBetaSet()
+ require.Contains(t, base, claude.BetaContext1M)
+ require.Contains(t, base, claude.BetaFastMode)
+ require.Len(t, base, len(claude.DroppedBetas))
+
+ // With extra tokens
+ extended := droppedBetaSet(claude.BetaClaudeCode)
+ require.Contains(t, extended, claude.BetaContext1M)
+ require.Contains(t, extended, claude.BetaFastMode)
+ require.Contains(t, extended, claude.BetaClaudeCode)
+ require.Len(t, extended, len(claude.DroppedBetas)+1)
+}
+
+func TestBuildBetaTokenSet(t *testing.T) {
+ got := buildBetaTokenSet([]string{"foo", "", "bar", "foo"})
+ require.Len(t, got, 2)
+ require.Contains(t, got, "foo")
+ require.Contains(t, got, "bar")
+ require.NotContains(t, got, "")
+
+ empty := buildBetaTokenSet(nil)
+ require.Empty(t, empty)
+}
+
+func TestStripBetaTokensWithSet_EmptyDropSet(t *testing.T) {
+ header := "oauth-2025-04-20,interleaved-thinking-2025-05-14"
+ got := stripBetaTokensWithSet(header, map[string]struct{}{})
+ require.Equal(t, header, got)
+}
+
+func TestIsCountTokensUnsupported404(t *testing.T) {
+ tests := []struct {
+ name string
+ statusCode int
+ body string
+ want bool
+ }{
+ {
+ name: "exact endpoint not found",
+ statusCode: 404,
+ body: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"not_found_error"}}`,
+ want: true,
+ },
+ {
+ name: "contains count_tokens and not found",
+ statusCode: 404,
+ body: `{"error":{"message":"count_tokens route not found","type":"not_found_error"}}`,
+ want: true,
+ },
+ {
+ name: "generic 404",
+ statusCode: 404,
+ body: `{"error":{"message":"resource not found","type":"not_found_error"}}`,
+ want: false,
+ },
+ {
+ name: "404 with empty error message",
+ statusCode: 404,
+ body: `{"error":{"message":"","type":"not_found_error"}}`,
+ want: false,
+ },
+ {
+ name: "non-404 status",
+ statusCode: 400,
+ body: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"invalid_request_error"}}`,
+ want: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := isCountTokensUnsupported404(tt.statusCode, []byte(tt.body))
+ require.Equal(t, tt.want, got)
+ })
+ }
+}
diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go
index 5055eec0..067a0e08 100644
--- a/backend/internal/service/gateway_multiplatform_test.go
+++ b/backend/internal/service/gateway_multiplatform_test.go
@@ -1892,6 +1892,14 @@ func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accoun
return 0, nil
}
+func (m *mockConcurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
+ result := make(map[int64]int, len(accountIDs))
+ for _, accountID := range accountIDs {
+ result[accountID] = 0
+ }
+ return result, nil
+}
+
func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return true, nil
}
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 5c14e7f9..3323f868 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -127,13 +127,26 @@ func WithForceCacheBilling(ctx context.Context) context.Context {
}
func (s *GatewayService) debugModelRoutingEnabled() bool {
- v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
- return v == "1" || v == "true" || v == "yes" || v == "on"
+ if s == nil {
+ return false
+ }
+ return s.debugModelRouting.Load()
}
func (s *GatewayService) debugClaudeMimicEnabled() bool {
- v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC")))
- return v == "1" || v == "true" || v == "yes" || v == "on"
+ if s == nil {
+ return false
+ }
+ return s.debugClaudeMimic.Load()
+}
+
+func parseDebugEnvBool(raw string) bool {
+ switch strings.ToLower(strings.TrimSpace(raw)) {
+ case "1", "true", "yes", "on":
+ return true
+ default:
+ return false
+ }
}
func shortSessionHash(sessionHash string) string {
@@ -374,37 +387,16 @@ func modelsListCacheKey(groupID *int64, platform string) string {
}
func prefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) {
- if ctx == nil {
- return 0, false
- }
- v := ctx.Value(ctxkey.PrefetchedStickyGroupID)
- switch t := v.(type) {
- case int64:
- return t, true
- case int:
- return int64(t), true
- }
- return 0, false
+ return PrefetchedStickyGroupIDFromContext(ctx)
}
func prefetchedStickyAccountIDFromContext(ctx context.Context, groupID *int64) int64 {
- if ctx == nil {
- return 0
- }
prefetchedGroupID, ok := prefetchedStickyGroupIDFromContext(ctx)
if !ok || prefetchedGroupID != derefGroupID(groupID) {
return 0
}
- v := ctx.Value(ctxkey.PrefetchedStickyAccountID)
- switch t := v.(type) {
- case int64:
- if t > 0 {
- return t
- }
- case int:
- if t > 0 {
- return int64(t)
- }
+ if accountID, ok := PrefetchedStickyAccountIDFromContext(ctx); ok && accountID > 0 {
+ return accountID
}
return 0
}
@@ -470,7 +462,7 @@ type ForwardResult struct {
FirstTokenMs *int // 首字时间(流式请求)
ClientDisconnect bool // 客户端是否在流式传输过程中断开
- // 图片生成计费字段(仅 gemini-3-pro-image 使用)
+ // 图片生成计费字段(图片生成模型使用)
ImageCount int // 生成的图片数量
ImageSize string // 图片尺寸 "1K", "2K", "4K"
@@ -509,29 +501,33 @@ func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accou
// GatewayService handles API gateway operations
type GatewayService struct {
- accountRepo AccountRepository
- groupRepo GroupRepository
- usageLogRepo UsageLogRepository
- userRepo UserRepository
- userSubRepo UserSubscriptionRepository
- userGroupRateRepo UserGroupRateRepository
- cache GatewayCache
- digestStore *DigestSessionStore
- cfg *config.Config
- schedulerSnapshot *SchedulerSnapshotService
- billingService *BillingService
- rateLimitService *RateLimitService
- billingCacheService *BillingCacheService
- identityService *IdentityService
- httpUpstream HTTPUpstream
- deferredService *DeferredService
- concurrencyService *ConcurrencyService
- claudeTokenProvider *ClaudeTokenProvider
- sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
- userGroupRateCache *gocache.Cache
- userGroupRateSF singleflight.Group
- modelsListCache *gocache.Cache
- modelsListCacheTTL time.Duration
+ accountRepo AccountRepository
+ groupRepo GroupRepository
+ usageLogRepo UsageLogRepository
+ userRepo UserRepository
+ userSubRepo UserSubscriptionRepository
+ userGroupRateRepo UserGroupRateRepository
+ cache GatewayCache
+ digestStore *DigestSessionStore
+ cfg *config.Config
+ schedulerSnapshot *SchedulerSnapshotService
+ billingService *BillingService
+ rateLimitService *RateLimitService
+ billingCacheService *BillingCacheService
+ identityService *IdentityService
+ httpUpstream HTTPUpstream
+ deferredService *DeferredService
+ concurrencyService *ConcurrencyService
+ claudeTokenProvider *ClaudeTokenProvider
+ sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
+ rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
+ userGroupRateCache *gocache.Cache
+ userGroupRateSF singleflight.Group
+ modelsListCache *gocache.Cache
+ modelsListCacheTTL time.Duration
+ responseHeaderFilter *responseheaders.CompiledHeaderFilter
+ debugModelRouting atomic.Bool
+ debugClaudeMimic atomic.Bool
}
// NewGatewayService creates a new GatewayService
@@ -554,35 +550,41 @@ func NewGatewayService(
deferredService *DeferredService,
claudeTokenProvider *ClaudeTokenProvider,
sessionLimitCache SessionLimitCache,
+ rpmCache RPMCache,
digestStore *DigestSessionStore,
) *GatewayService {
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
modelsListTTL := resolveModelsListCacheTTL(cfg)
- return &GatewayService{
- accountRepo: accountRepo,
- groupRepo: groupRepo,
- usageLogRepo: usageLogRepo,
- userRepo: userRepo,
- userSubRepo: userSubRepo,
- userGroupRateRepo: userGroupRateRepo,
- cache: cache,
- digestStore: digestStore,
- cfg: cfg,
- schedulerSnapshot: schedulerSnapshot,
- concurrencyService: concurrencyService,
- billingService: billingService,
- rateLimitService: rateLimitService,
- billingCacheService: billingCacheService,
- identityService: identityService,
- httpUpstream: httpUpstream,
- deferredService: deferredService,
- claudeTokenProvider: claudeTokenProvider,
- sessionLimitCache: sessionLimitCache,
- userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
- modelsListCache: gocache.New(modelsListTTL, time.Minute),
- modelsListCacheTTL: modelsListTTL,
+ svc := &GatewayService{
+ accountRepo: accountRepo,
+ groupRepo: groupRepo,
+ usageLogRepo: usageLogRepo,
+ userRepo: userRepo,
+ userSubRepo: userSubRepo,
+ userGroupRateRepo: userGroupRateRepo,
+ cache: cache,
+ digestStore: digestStore,
+ cfg: cfg,
+ schedulerSnapshot: schedulerSnapshot,
+ concurrencyService: concurrencyService,
+ billingService: billingService,
+ rateLimitService: rateLimitService,
+ billingCacheService: billingCacheService,
+ identityService: identityService,
+ httpUpstream: httpUpstream,
+ deferredService: deferredService,
+ claudeTokenProvider: claudeTokenProvider,
+ sessionLimitCache: sessionLimitCache,
+ rpmCache: rpmCache,
+ userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
+ modelsListCache: gocache.New(modelsListTTL, time.Minute),
+ modelsListCacheTTL: modelsListTTL,
+ responseHeaderFilter: compileResponseHeaderFilter(cfg),
}
+ svc.debugModelRouting.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
+ svc.debugClaudeMimic.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC")))
+ return svc
}
// GenerateSessionHash 从预解析请求计算粘性会话 hash
@@ -1155,6 +1157,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return nil, errors.New("no available accounts")
}
ctx = s.withWindowCostPrefetch(ctx, accounts)
+ ctx = s.withRPMPrefetch(ctx, accounts)
isExcluded := func(accountID int64) bool {
if excludedIDs == nil {
@@ -1204,7 +1207,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
continue
}
account, ok := accountByID[routingAccountID]
- if !ok || !account.IsSchedulable() {
+ if !ok || !s.isAccountSchedulableForSelection(account) {
if !ok {
filteredMissing++
} else {
@@ -1220,7 +1223,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
filteredModelMapping++
continue
}
- if !account.IsSchedulableForModelWithContext(ctx, requestedModel) {
+ if !s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) {
filteredModelScope++
modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID)
continue
@@ -1230,6 +1233,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
filteredWindowCost++
continue
}
+ // RPM 检查(非粘性会话路径)
+ if !s.isAccountSchedulableForRPM(ctx, account, false) {
+ continue
+ }
routingCandidates = append(routingCandidates, account)
}
@@ -1249,11 +1256,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
// 粘性账号在路由列表中,优先使用
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
- if stickyAccount.IsSchedulable() &&
+ if s.isAccountSchedulableForSelection(stickyAccount) &&
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
- stickyAccount.IsSchedulableForModelWithContext(ctx, requestedModel) &&
- s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查
+ s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
+ s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) &&
+
+ s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
@@ -1406,8 +1415,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !clearSticky && s.isAccountInGroup(account, groupID) &&
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) &&
- account.IsSchedulableForModelWithContext(ctx, requestedModel) &&
- s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
+ s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) &&
+ s.isAccountSchedulableForWindowCost(ctx, account, true) &&
+
+ s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
@@ -1457,7 +1468,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
// Scheduler snapshots can be temporarily stale (bucket rebuild is throttled);
// re-check schedulability here so recently rate-limited/overloaded accounts
// are not selected again before the bucket is rebuilt.
- if !acc.IsSchedulable() {
+ if !s.isAccountSchedulableForSelection(acc) {
continue
}
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
@@ -1466,13 +1477,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
- if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
+ if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
// 窗口费用检查(非粘性会话路径)
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
+ // RPM 检查(非粘性会话路径)
+ if !s.isAccountSchedulableForRPM(ctx, acc, false) {
+ continue
+ }
candidates = append(candidates, acc)
}
@@ -1737,6 +1752,9 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
}
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
+ if platform == PlatformSora {
+ return s.listSoraSchedulableAccounts(ctx, groupID)
+ }
if s.schedulerSnapshot != nil {
accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err == nil {
@@ -1831,6 +1849,53 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
return accounts, useMixed, nil
}
+func (s *GatewayService) listSoraSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, bool, error) {
+ const useMixed = false
+
+ var accounts []Account
+ var err error
+ if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
+ accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora)
+ } else if groupID != nil {
+ accounts, err = s.accountRepo.ListByGroup(ctx, *groupID)
+ } else {
+ accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora)
+ }
+ if err != nil {
+ slog.Debug("account_scheduling_list_failed",
+ "group_id", derefGroupID(groupID),
+ "platform", PlatformSora,
+ "error", err)
+ return nil, useMixed, err
+ }
+
+ filtered := make([]Account, 0, len(accounts))
+ for _, acc := range accounts {
+ if acc.Platform != PlatformSora {
+ continue
+ }
+ if !s.isSoraAccountSchedulable(&acc) {
+ continue
+ }
+ filtered = append(filtered, acc)
+ }
+ slog.Debug("account_scheduling_list_sora",
+ "group_id", derefGroupID(groupID),
+ "platform", PlatformSora,
+ "raw_count", len(accounts),
+ "filtered_count", len(filtered))
+ for _, acc := range filtered {
+ slog.Debug("account_scheduling_account_detail",
+ "account_id", acc.ID,
+ "name", acc.Name,
+ "platform", acc.Platform,
+ "type", acc.Type,
+ "status", acc.Status,
+ "tls_fingerprint", acc.IsTLSFingerprintEnabled())
+ }
+ return filtered, useMixed, nil
+}
+
// IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。
// 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context,
// 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。
@@ -1855,6 +1920,49 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform
return account.Platform == platform
}
+func (s *GatewayService) isSoraAccountSchedulable(account *Account) bool {
+ return s.soraUnschedulableReason(account) == ""
+}
+
+func (s *GatewayService) soraUnschedulableReason(account *Account) string {
+ if account == nil {
+ return "account_nil"
+ }
+ if account.Status != StatusActive {
+ return fmt.Sprintf("status=%s", account.Status)
+ }
+ if !account.Schedulable {
+ return "schedulable=false"
+ }
+ if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
+ return fmt.Sprintf("temp_unschedulable_until=%s", account.TempUnschedulableUntil.UTC().Format(time.RFC3339))
+ }
+ return ""
+}
+
+func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool {
+ if account == nil {
+ return false
+ }
+ if account.Platform == PlatformSora {
+ return s.isSoraAccountSchedulable(account)
+ }
+ return account.IsSchedulable()
+}
+
+func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Context, account *Account, requestedModel string) bool {
+ if account == nil {
+ return false
+ }
+ if account.Platform == PlatformSora {
+ if !s.isSoraAccountSchedulable(account) {
+ return false
+ }
+ return account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel) <= 0
+ }
+ return account.IsSchedulableForModelWithContext(ctx, requestedModel)
+}
+
// isAccountInGroup checks if the account belongs to the specified group.
// Returns true if groupID is nil (no group restriction) or account belongs to the group.
func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool {
@@ -2063,6 +2171,88 @@ checkSchedulability:
return true
}
+// rpmPrefetchContextKey is the context key for prefetched RPM counts.
+type rpmPrefetchContextKeyType struct{}
+
+var rpmPrefetchContextKey = rpmPrefetchContextKeyType{}
+
+func rpmFromPrefetchContext(ctx context.Context, accountID int64) (int, bool) {
+ if v, ok := ctx.Value(rpmPrefetchContextKey).(map[int64]int); ok {
+ count, found := v[accountID]
+ return count, found
+ }
+ return 0, false
+}
+
+// withRPMPrefetch 批量预取所有候选账号的 RPM 计数
+func (s *GatewayService) withRPMPrefetch(ctx context.Context, accounts []Account) context.Context {
+ if s.rpmCache == nil {
+ return ctx
+ }
+
+ var ids []int64
+ for i := range accounts {
+ if accounts[i].IsAnthropicOAuthOrSetupToken() && accounts[i].GetBaseRPM() > 0 {
+ ids = append(ids, accounts[i].ID)
+ }
+ }
+ if len(ids) == 0 {
+ return ctx
+ }
+
+ counts, err := s.rpmCache.GetRPMBatch(ctx, ids)
+ if err != nil {
+ return ctx // 失败开放
+ }
+ return context.WithValue(ctx, rpmPrefetchContextKey, counts)
+}
+
+// isAccountSchedulableForRPM 检查账号是否可根据 RPM 进行调度
+// 仅适用于 Anthropic OAuth/SetupToken 账号
+func (s *GatewayService) isAccountSchedulableForRPM(ctx context.Context, account *Account, isSticky bool) bool {
+ if !account.IsAnthropicOAuthOrSetupToken() {
+ return true
+ }
+ baseRPM := account.GetBaseRPM()
+ if baseRPM <= 0 {
+ return true
+ }
+
+ // 尝试从预取缓存获取
+ var currentRPM int
+ if count, ok := rpmFromPrefetchContext(ctx, account.ID); ok {
+ currentRPM = count
+ } else if s.rpmCache != nil {
+ if count, err := s.rpmCache.GetRPM(ctx, account.ID); err == nil {
+ currentRPM = count
+ }
+ // 失败开放:GetRPM 错误时允许调度
+ }
+
+ schedulability := account.CheckRPMSchedulability(currentRPM)
+ switch schedulability {
+ case WindowCostSchedulable:
+ return true
+ case WindowCostStickyOnly:
+ return isSticky
+ case WindowCostNotSchedulable:
+ return false
+ }
+ return true
+}
+
+// IncrementAccountRPM increments the RPM counter for the given account.
+// 已知 TOCTOU 竞态:调度时读取 RPM 计数与此处递增之间存在时间窗口,
+// 高并发下可能短暂超出 RPM 限制。这是与 WindowCost 一致的 soft-limit
+// 设计权衡——可接受的少量超额优于加锁带来的延迟和复杂度。
+func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int64) error {
+ if s.rpmCache == nil {
+ return nil
+ }
+ _, err := s.rpmCache.IncrementRPM(ctx, accountID)
+ return err
+}
+
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
// 仅适用于 Anthropic OAuth/SetupToken 账号
// sessionID: 会话标识符(使用粘性会话的 hash)
@@ -2257,7 +2447,7 @@ func sameAccountWithLoadGroup(a, b accountWithLoad) bool {
// shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。
//
// 注意:当 preferOAuth=true 时,需要保证 OAuth 账号在同组内仍然优先,否则会把排序时的偏好打散掉。
-// 因此这里采用“组内分区 + 分区内 shuffle”的方式:
+// 因此这里采用"组内分区 + 分区内 shuffle"的方式:
// - 先把同组账号按 (OAuth / 非 OAuth) 拆成两段,保持 OAuth 段在前;
// - 再分别在各段内随机打散,避免热点。
func shuffleWithinPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
@@ -2397,7 +2587,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
- if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
+ if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
@@ -2420,6 +2610,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
accountsLoaded = true
+ // 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存
+ ctx = s.withWindowCostPrefetch(ctx, accounts)
+ ctx = s.withRPMPrefetch(ctx, accounts)
+
routingSet := make(map[int64]struct{}, len(routingAccountIDs))
for _, id := range routingAccountIDs {
if id > 0 {
@@ -2438,13 +2632,19 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
- if !acc.IsSchedulable() {
+ if !s.isAccountSchedulableForSelection(acc) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
- if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
+ if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
+ continue
+ }
+ if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
+ continue
+ }
+ if !s.isAccountSchedulableForRPM(ctx, acc, false) {
continue
}
if selected == nil {
@@ -2497,7 +2697,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
- if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
+ if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
return account, nil
}
}
@@ -2518,6 +2718,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
}
+ // 批量预取窗口费用+RPM 计数,避免逐个账号查询(N+1)
+ ctx = s.withWindowCostPrefetch(ctx, accounts)
+ ctx = s.withRPMPrefetch(ctx, accounts)
+
// 3. 按优先级+最久未用选择(考虑模型支持)
var selected *Account
for i := range accounts {
@@ -2527,13 +2731,19 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
- if !acc.IsSchedulable() {
+ if !s.isAccountSchedulableForSelection(acc) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
- if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
+ if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
+ continue
+ }
+ if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
+ continue
+ }
+ if !s.isAccountSchedulableForRPM(ctx, acc, false) {
continue
}
if selected == nil {
@@ -2561,8 +2771,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
if selected == nil {
+ stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false)
if requestedModel != "" {
- return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel)
+ return nil, fmt.Errorf("no available accounts supporting model: %s (%s)", requestedModel, summarizeSelectionFailureStats(stats))
}
return nil, errors.New("no available accounts")
}
@@ -2604,7 +2815,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
- if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
+ if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
@@ -2625,6 +2836,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
accountsLoaded = true
+ // 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存
+ ctx = s.withWindowCostPrefetch(ctx, accounts)
+ ctx = s.withRPMPrefetch(ctx, accounts)
+
routingSet := make(map[int64]struct{}, len(routingAccountIDs))
for _, id := range routingAccountIDs {
if id > 0 {
@@ -2643,7 +2858,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
- if !acc.IsSchedulable() {
+ if !s.isAccountSchedulableForSelection(acc) {
continue
}
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
@@ -2653,7 +2868,13 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
- if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
+ if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
+ continue
+ }
+ if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
+ continue
+ }
+ if !s.isAccountSchedulableForRPM(ctx, acc, false) {
continue
}
if selected == nil {
@@ -2706,7 +2927,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
- if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) {
+ if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
return account, nil
}
@@ -2725,6 +2946,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
}
+ // 批量预取窗口费用+RPM 计数,避免逐个账号查询(N+1)
+ ctx = s.withWindowCostPrefetch(ctx, accounts)
+ ctx = s.withRPMPrefetch(ctx, accounts)
+
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
var selected *Account
for i := range accounts {
@@ -2734,7 +2959,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
- if !acc.IsSchedulable() {
+ if !s.isAccountSchedulableForSelection(acc) {
continue
}
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
@@ -2744,7 +2969,13 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
- if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) {
+ if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
+ continue
+ }
+ if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
+ continue
+ }
+ if !s.isAccountSchedulableForRPM(ctx, acc, false) {
continue
}
if selected == nil {
@@ -2772,8 +3003,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
if selected == nil {
+ stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, nativePlatform, accounts, excludedIDs, true)
if requestedModel != "" {
- return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel)
+ return nil, fmt.Errorf("no available accounts supporting model: %s (%s)", requestedModel, summarizeSelectionFailureStats(stats))
}
return nil, errors.New("no available accounts")
}
@@ -2788,6 +3020,236 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
return selected, nil
}
+type selectionFailureStats struct {
+ Total int
+ Eligible int
+ Excluded int
+ Unschedulable int
+ PlatformFiltered int
+ ModelUnsupported int
+ ModelRateLimited int
+ SamplePlatformIDs []int64
+ SampleMappingIDs []int64
+ SampleRateLimitIDs []string
+}
+
+type selectionFailureDiagnosis struct {
+ Category string
+ Detail string
+}
+
+func (s *GatewayService) logDetailedSelectionFailure(
+ ctx context.Context,
+ groupID *int64,
+ sessionHash string,
+ requestedModel string,
+ platform string,
+ accounts []Account,
+ excludedIDs map[int64]struct{},
+ allowMixedScheduling bool,
+) selectionFailureStats {
+ stats := s.collectSelectionFailureStats(ctx, accounts, requestedModel, platform, excludedIDs, allowMixedScheduling)
+ logger.LegacyPrintf(
+ "service.gateway",
+ "[SelectAccountDetailed] group_id=%v model=%s platform=%s session=%s total=%d eligible=%d excluded=%d unschedulable=%d platform_filtered=%d model_unsupported=%d model_rate_limited=%d sample_platform_filtered=%v sample_model_unsupported=%v sample_model_rate_limited=%v",
+ derefGroupID(groupID),
+ requestedModel,
+ platform,
+ shortSessionHash(sessionHash),
+ stats.Total,
+ stats.Eligible,
+ stats.Excluded,
+ stats.Unschedulable,
+ stats.PlatformFiltered,
+ stats.ModelUnsupported,
+ stats.ModelRateLimited,
+ stats.SamplePlatformIDs,
+ stats.SampleMappingIDs,
+ stats.SampleRateLimitIDs,
+ )
+ if platform == PlatformSora {
+ s.logSoraSelectionFailureDetails(ctx, groupID, sessionHash, requestedModel, accounts, excludedIDs, allowMixedScheduling)
+ }
+ return stats
+}
+
+func (s *GatewayService) collectSelectionFailureStats(
+ ctx context.Context,
+ accounts []Account,
+ requestedModel string,
+ platform string,
+ excludedIDs map[int64]struct{},
+ allowMixedScheduling bool,
+) selectionFailureStats {
+ stats := selectionFailureStats{
+ Total: len(accounts),
+ }
+
+ for i := range accounts {
+ acc := &accounts[i]
+ diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, platform, excludedIDs, allowMixedScheduling)
+ switch diagnosis.Category {
+ case "excluded":
+ stats.Excluded++
+ case "unschedulable":
+ stats.Unschedulable++
+ case "platform_filtered":
+ stats.PlatformFiltered++
+ stats.SamplePlatformIDs = appendSelectionFailureSampleID(stats.SamplePlatformIDs, acc.ID)
+ case "model_unsupported":
+ stats.ModelUnsupported++
+ stats.SampleMappingIDs = appendSelectionFailureSampleID(stats.SampleMappingIDs, acc.ID)
+ case "model_rate_limited":
+ stats.ModelRateLimited++
+ remaining := acc.GetRateLimitRemainingTimeWithContext(ctx, requestedModel).Truncate(time.Second)
+ stats.SampleRateLimitIDs = appendSelectionFailureRateSample(stats.SampleRateLimitIDs, acc.ID, remaining)
+ default:
+ stats.Eligible++
+ }
+ }
+
+ return stats
+}
+
+func (s *GatewayService) diagnoseSelectionFailure(
+ ctx context.Context,
+ acc *Account,
+ requestedModel string,
+ platform string,
+ excludedIDs map[int64]struct{},
+ allowMixedScheduling bool,
+) selectionFailureDiagnosis {
+ if acc == nil {
+ return selectionFailureDiagnosis{Category: "unschedulable", Detail: "account_nil"}
+ }
+ if _, excluded := excludedIDs[acc.ID]; excluded {
+ return selectionFailureDiagnosis{Category: "excluded"}
+ }
+ if !s.isAccountSchedulableForSelection(acc) {
+ detail := "generic_unschedulable"
+ if acc.Platform == PlatformSora {
+ detail = s.soraUnschedulableReason(acc)
+ }
+ return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail}
+ }
+ if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) {
+ return selectionFailureDiagnosis{
+ Category: "platform_filtered",
+ Detail: fmt.Sprintf("account_platform=%s requested_platform=%s", acc.Platform, strings.TrimSpace(platform)),
+ }
+ }
+ if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
+ return selectionFailureDiagnosis{
+ Category: "model_unsupported",
+ Detail: fmt.Sprintf("model=%s", requestedModel),
+ }
+ }
+ if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
+ remaining := acc.GetRateLimitRemainingTimeWithContext(ctx, requestedModel).Truncate(time.Second)
+ return selectionFailureDiagnosis{
+ Category: "model_rate_limited",
+ Detail: fmt.Sprintf("remaining=%s", remaining),
+ }
+ }
+ return selectionFailureDiagnosis{Category: "eligible"}
+}
+
+func (s *GatewayService) logSoraSelectionFailureDetails(
+ ctx context.Context,
+ groupID *int64,
+ sessionHash string,
+ requestedModel string,
+ accounts []Account,
+ excludedIDs map[int64]struct{},
+ allowMixedScheduling bool,
+) {
+ const maxLines = 30
+ logged := 0
+
+ for i := range accounts {
+ if logged >= maxLines {
+ break
+ }
+ acc := &accounts[i]
+ diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, PlatformSora, excludedIDs, allowMixedScheduling)
+ if diagnosis.Category == "eligible" {
+ continue
+ }
+ detail := diagnosis.Detail
+ if detail == "" {
+ detail = "-"
+ }
+ logger.LegacyPrintf(
+ "service.gateway",
+ "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s",
+ derefGroupID(groupID),
+ requestedModel,
+ shortSessionHash(sessionHash),
+ acc.ID,
+ acc.Platform,
+ diagnosis.Category,
+ detail,
+ )
+ logged++
+ }
+ if len(accounts) > maxLines {
+ logger.LegacyPrintf(
+ "service.gateway",
+ "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d",
+ derefGroupID(groupID),
+ requestedModel,
+ shortSessionHash(sessionHash),
+ len(accounts),
+ logged,
+ )
+ }
+}
+
+func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool {
+ if acc == nil {
+ return true
+ }
+ if allowMixedScheduling {
+ if acc.Platform == PlatformAntigravity {
+ return !acc.IsMixedSchedulingEnabled()
+ }
+ return acc.Platform != platform
+ }
+ if strings.TrimSpace(platform) == "" {
+ return false
+ }
+ return acc.Platform != platform
+}
+
+func appendSelectionFailureSampleID(samples []int64, id int64) []int64 {
+ const limit = 5
+ if len(samples) >= limit {
+ return samples
+ }
+ return append(samples, id)
+}
+
+func appendSelectionFailureRateSample(samples []string, accountID int64, remaining time.Duration) []string {
+ const limit = 5
+ if len(samples) >= limit {
+ return samples
+ }
+ return append(samples, fmt.Sprintf("%d(%s)", accountID, remaining))
+}
+
+func summarizeSelectionFailureStats(stats selectionFailureStats) string {
+ return fmt.Sprintf(
+ "total=%d eligible=%d excluded=%d unschedulable=%d platform_filtered=%d model_unsupported=%d model_rate_limited=%d",
+ stats.Total,
+ stats.Eligible,
+ stats.Excluded,
+ stats.Unschedulable,
+ stats.PlatformFiltered,
+ stats.ModelUnsupported,
+ stats.ModelRateLimited,
+ )
+}
+
// isModelSupportedByAccountWithContext 根据账户平台检查模型支持(带 context)
// 对于 Antigravity 平台,会先获取映射后的最终模型名(包括 thinking 后缀)再检查支持
func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Context, account *Account, requestedModel string) bool {
@@ -2801,7 +3263,7 @@ func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Contex
return false
}
// 应用 thinking 后缀后检查最终模型是否在账号映射中
- if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok {
+ if enabled, ok := ThinkingEnabledFromContext(ctx); ok {
finalModel := applyThinkingModelSuffix(mapped, enabled)
if finalModel == mapped {
return true // thinking 后缀未改变模型名,映射已通过
@@ -2821,6 +3283,9 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
}
return mapAntigravityModel(account, requestedModel) != ""
}
+ if account.Platform == PlatformSora {
+ return s.isSoraModelSupportedByAccount(account, requestedModel)
+ }
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
requestedModel = claude.NormalizeModelID(requestedModel)
@@ -2829,6 +3294,143 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
return account.IsModelSupported(requestedModel)
}
+func (s *GatewayService) isSoraModelSupportedByAccount(account *Account, requestedModel string) bool {
+ if account == nil {
+ return false
+ }
+ if strings.TrimSpace(requestedModel) == "" {
+ return true
+ }
+
+ // 先走原始精确/通配符匹配。
+ mapping := account.GetModelMapping()
+ if len(mapping) == 0 || account.IsModelSupported(requestedModel) {
+ return true
+ }
+
+ aliases := buildSoraModelAliases(requestedModel)
+ if len(aliases) == 0 {
+ return false
+ }
+
+ hasSoraSelector := false
+ for pattern := range mapping {
+ if !isSoraModelSelector(pattern) {
+ continue
+ }
+ hasSoraSelector = true
+ if matchPatternAnyAlias(pattern, aliases) {
+ return true
+ }
+ }
+
+ // 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*),
+ // 此时不应误拦截 Sora 模型请求。
+ if !hasSoraSelector {
+ return true
+ }
+
+ return false
+}
+
+func matchPatternAnyAlias(pattern string, aliases []string) bool {
+ normalizedPattern := strings.ToLower(strings.TrimSpace(pattern))
+ if normalizedPattern == "" {
+ return false
+ }
+ for _, alias := range aliases {
+ if matchWildcard(normalizedPattern, alias) {
+ return true
+ }
+ }
+ return false
+}
+
+func isSoraModelSelector(pattern string) bool {
+ p := strings.ToLower(strings.TrimSpace(pattern))
+ if p == "" {
+ return false
+ }
+
+ switch {
+ case strings.HasPrefix(p, "sora"),
+ strings.HasPrefix(p, "gpt-image"),
+ strings.HasPrefix(p, "prompt-enhance"),
+ strings.HasPrefix(p, "sy_"):
+ return true
+ }
+
+ return p == "video" || p == "image"
+}
+
+func buildSoraModelAliases(requestedModel string) []string {
+ modelID := strings.ToLower(strings.TrimSpace(requestedModel))
+ if modelID == "" {
+ return nil
+ }
+
+ aliases := make([]string, 0, 8)
+ addAlias := func(value string) {
+ v := strings.ToLower(strings.TrimSpace(value))
+ if v == "" {
+ return
+ }
+ for _, existing := range aliases {
+ if existing == v {
+ return
+ }
+ }
+ aliases = append(aliases, v)
+ }
+
+ addAlias(modelID)
+ cfg, ok := GetSoraModelConfig(modelID)
+ if ok {
+ addAlias(cfg.Model)
+ switch cfg.Type {
+ case "video":
+ addAlias("video")
+ addAlias("sora")
+ addAlias(soraVideoFamilyAlias(modelID))
+ case "image":
+ addAlias("image")
+ addAlias("gpt-image")
+ case "prompt_enhance":
+ addAlias("prompt-enhance")
+ }
+ return aliases
+ }
+
+ switch {
+ case strings.HasPrefix(modelID, "sora"):
+ addAlias("video")
+ addAlias("sora")
+ addAlias(soraVideoFamilyAlias(modelID))
+ case strings.HasPrefix(modelID, "gpt-image"):
+ addAlias("image")
+ addAlias("gpt-image")
+ case strings.HasPrefix(modelID, "prompt-enhance"):
+ addAlias("prompt-enhance")
+ default:
+ return nil
+ }
+
+ return aliases
+}
+
+func soraVideoFamilyAlias(modelID string) string {
+ switch {
+ case strings.HasPrefix(modelID, "sora2pro-hd"):
+ return "sora2pro-hd"
+ case strings.HasPrefix(modelID, "sora2pro"):
+ return "sora2pro"
+ case strings.HasPrefix(modelID, "sora2"):
+ return "sora2"
+ default:
+ return ""
+ }
+}
+
// GetAccessToken 获取账号凭证
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
@@ -4012,7 +4614,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
}
- writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg)
+ writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
contentType := strings.TrimSpace(resp.Header.Get("Content-Type"))
if contentType == "" {
@@ -4308,7 +4910,7 @@ func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough(
usage := parseClaudeUsageFromResponseBody(body)
- writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg)
+ writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
contentType := strings.TrimSpace(resp.Header.Get("Content-Type"))
if contentType == "" {
contentType = "application/json"
@@ -4317,12 +4919,12 @@ func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough(
return usage, nil
}
-func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header, cfg *config.Config) {
+func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) {
if dst == nil || src == nil {
return
}
- if cfg != nil {
- responseheaders.WriteFilteredHeaders(dst, src, cfg.Security.ResponseHeaders)
+ if filter != nil {
+ responseheaders.WriteFilteredHeaders(dst, src, filter)
return
}
if v := strings.TrimSpace(src.Get("Content-Type")); v != "" {
@@ -4425,12 +5027,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// messages requests typically use only oauth + interleaved-thinking.
// Also drop claude-code beta if a downstream client added it.
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
- drop := map[string]struct{}{claude.BetaClaudeCode: {}, claude.BetaContext1M: {}}
- req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
+ req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, droppedBetasWithClaudeCodeSet))
} else {
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
clientBetaHeader := req.Header.Get("anthropic-beta")
- req.Header.Set("anthropic-beta", stripBetaToken(s.getBetaHeader(modelID, clientBetaHeader), claude.BetaContext1M))
+ req.Header.Set("anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), defaultDroppedBetasSet))
}
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
@@ -4584,23 +5185,64 @@ func mergeAnthropicBetaDropping(required []string, incoming string, drop map[str
return strings.Join(out, ",")
}
-// stripBetaToken removes a single beta token from a comma-separated header value.
-// It short-circuits when the token is not present to avoid unnecessary allocations.
-func stripBetaToken(header, token string) string {
- if !strings.Contains(header, token) {
+// stripBetaTokens removes the given beta tokens from a comma-separated header value.
+func stripBetaTokens(header string, tokens []string) string {
+ if header == "" || len(tokens) == 0 {
return header
}
- out := make([]string, 0, 8)
- for _, p := range strings.Split(header, ",") {
+ return stripBetaTokensWithSet(header, buildBetaTokenSet(tokens))
+}
+
+func stripBetaTokensWithSet(header string, drop map[string]struct{}) string {
+ if header == "" || len(drop) == 0 {
+ return header
+ }
+ parts := strings.Split(header, ",")
+ out := make([]string, 0, len(parts))
+ for _, p := range parts {
p = strings.TrimSpace(p)
- if p == "" || p == token {
+ if p == "" {
+ continue
+ }
+ if _, ok := drop[p]; ok {
continue
}
out = append(out, p)
}
+ if len(out) == len(parts) {
+ return header // no change, avoid allocation
+ }
return strings.Join(out, ",")
}
+// droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens.
+func droppedBetaSet(extra ...string) map[string]struct{} {
+ m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra))
+ for t := range defaultDroppedBetasSet {
+ m[t] = struct{}{}
+ }
+ for _, t := range extra {
+ m[t] = struct{}{}
+ }
+ return m
+}
+
+func buildBetaTokenSet(tokens []string) map[string]struct{} {
+ m := make(map[string]struct{}, len(tokens))
+ for _, t := range tokens {
+ if t == "" {
+ continue
+ }
+ m[t] = struct{}{}
+ }
+ return m
+}
+
+var (
+ defaultDroppedBetasSet = buildBetaTokenSet(claude.DroppedBetas)
+ droppedBetasWithClaudeCodeSet = droppedBetaSet(claude.BetaClaudeCode)
+)
+
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
// This mirrors opencode-anthropic-auth behavior: do not trust downstream
// headers when using Claude Code-scoped OAuth credentials.
@@ -4681,7 +5323,7 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
}
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
- // 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
+ // 只对"可能是兼容性差异导致"的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
@@ -4730,6 +5372,20 @@ func extractUpstreamErrorMessage(body []byte) string {
return gjson.GetBytes(body, "message").String()
}
+func isCountTokensUnsupported404(statusCode int, body []byte) bool {
+ if statusCode != http.StatusNotFound {
+ return false
+ }
+ msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(body)))
+ if msg == "" {
+ return false
+ }
+ if strings.Contains(msg, "/v1/messages/count_tokens") {
+ return true
+ }
+ return strings.Contains(msg, "count_tokens") && strings.Contains(msg, "not found")
+}
+
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
@@ -5007,8 +5663,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
- if s.cfg != nil {
- responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
+ if s.responseHeaderFilter != nil {
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
// 设置SSE响应头
@@ -5102,9 +5758,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
pendingEventLines := make([]string, 0, 4)
- processSSEEvent := func(lines []string) ([]string, string, error) {
+ processSSEEvent := func(lines []string) ([]string, string, *sseUsagePatch, error) {
if len(lines) == 0 {
- return nil, "", nil
+ return nil, "", nil, nil
}
eventName := ""
@@ -5121,11 +5777,11 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
if eventName == "error" {
- return nil, dataLine, errors.New("have error in stream")
+ return nil, dataLine, nil, errors.New("have error in stream")
}
if dataLine == "" {
- return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil
+ return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil, nil
}
if dataLine == "[DONE]" {
@@ -5134,7 +5790,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
block = "event: " + eventName + "\n"
}
block += "data: " + dataLine + "\n\n"
- return []string{block}, dataLine, nil
+ return []string{block}, dataLine, nil, nil
}
var event map[string]any
@@ -5145,25 +5801,26 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
block = "event: " + eventName + "\n"
}
block += "data: " + dataLine + "\n\n"
- return []string{block}, dataLine, nil
+ return []string{block}, dataLine, nil, nil
}
eventType, _ := event["type"].(string)
if eventName == "" {
eventName = eventType
}
+ eventChanged := false
// 兼容 Kimi cached_tokens → cache_read_input_tokens
if eventType == "message_start" {
if msg, ok := event["message"].(map[string]any); ok {
if u, ok := msg["usage"].(map[string]any); ok {
- reconcileCachedTokens(u)
+ eventChanged = reconcileCachedTokens(u) || eventChanged
}
}
}
if eventType == "message_delta" {
if u, ok := event["usage"].(map[string]any); ok {
- reconcileCachedTokens(u)
+ eventChanged = reconcileCachedTokens(u) || eventChanged
}
}
@@ -5173,13 +5830,13 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if eventType == "message_start" {
if msg, ok := event["message"].(map[string]any); ok {
if u, ok := msg["usage"].(map[string]any); ok {
- rewriteCacheCreationJSON(u, overrideTarget)
+ eventChanged = rewriteCacheCreationJSON(u, overrideTarget) || eventChanged
}
}
}
if eventType == "message_delta" {
if u, ok := event["usage"].(map[string]any); ok {
- rewriteCacheCreationJSON(u, overrideTarget)
+ eventChanged = rewriteCacheCreationJSON(u, overrideTarget) || eventChanged
}
}
}
@@ -5188,10 +5845,21 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if msg, ok := event["message"].(map[string]any); ok {
if model, ok := msg["model"].(string); ok && model == mappedModel {
msg["model"] = originalModel
+ eventChanged = true
}
}
}
+ usagePatch := s.extractSSEUsagePatch(event)
+ if !eventChanged {
+ block := ""
+ if eventName != "" {
+ block = "event: " + eventName + "\n"
+ }
+ block += "data: " + dataLine + "\n\n"
+ return []string{block}, dataLine, usagePatch, nil
+ }
+
newData, err := json.Marshal(event)
if err != nil {
// 序列化失败,直接透传原始数据
@@ -5200,7 +5868,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
block = "event: " + eventName + "\n"
}
block += "data: " + dataLine + "\n\n"
- return []string{block}, dataLine, nil
+ return []string{block}, dataLine, usagePatch, nil
}
block := ""
@@ -5208,7 +5876,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
block = "event: " + eventName + "\n"
}
block += "data: " + string(newData) + "\n\n"
- return []string{block}, string(newData), nil
+ return []string{block}, string(newData), usagePatch, nil
}
for {
@@ -5246,7 +5914,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
continue
}
- outputBlocks, data, err := processSSEEvent(pendingEventLines)
+ outputBlocks, data, usagePatch, err := processSSEEvent(pendingEventLines)
pendingEventLines = pendingEventLines[:0]
if err != nil {
if clientDisconnected {
@@ -5269,7 +5937,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
- s.parseSSEUsage(data, usage)
+ if usagePatch != nil {
+ mergeSSEUsagePatch(usage, usagePatch)
+ }
}
}
continue
@@ -5300,64 +5970,163 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
- // 解析message_start获取input tokens(标准Claude API格式)
- var msgStart struct {
- Type string `json:"type"`
- Message struct {
- Usage ClaudeUsage `json:"usage"`
- } `json:"message"`
- }
- if json.Unmarshal([]byte(data), &msgStart) == nil && msgStart.Type == "message_start" {
- usage.InputTokens = msgStart.Message.Usage.InputTokens
- usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens
- usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
-
- // 解析嵌套的 cache_creation 对象中的 5m/1h 明细
- cc5m := gjson.Get(data, "message.usage.cache_creation.ephemeral_5m_input_tokens")
- cc1h := gjson.Get(data, "message.usage.cache_creation.ephemeral_1h_input_tokens")
- if cc5m.Exists() || cc1h.Exists() {
- usage.CacheCreation5mTokens = int(cc5m.Int())
- usage.CacheCreation1hTokens = int(cc1h.Int())
- }
+ if usage == nil {
+ return
}
- // 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API)
- var msgDelta struct {
- Type string `json:"type"`
- Usage struct {
- InputTokens int `json:"input_tokens"`
- OutputTokens int `json:"output_tokens"`
- CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
- CacheReadInputTokens int `json:"cache_read_input_tokens"`
- } `json:"usage"`
+ var event map[string]any
+ if err := json.Unmarshal([]byte(data), &event); err != nil {
+ return
}
- if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" {
- // message_delta 仅覆盖存在且非0的字段
- // 避免覆盖 message_start 中已有的值(如 input_tokens)
- // Claude API 的 message_delta 通常只包含 output_tokens
- if msgDelta.Usage.InputTokens > 0 {
- usage.InputTokens = msgDelta.Usage.InputTokens
- }
- if msgDelta.Usage.OutputTokens > 0 {
- usage.OutputTokens = msgDelta.Usage.OutputTokens
- }
- if msgDelta.Usage.CacheCreationInputTokens > 0 {
- usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens
- }
- if msgDelta.Usage.CacheReadInputTokens > 0 {
- usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
+
+ if patch := s.extractSSEUsagePatch(event); patch != nil {
+ mergeSSEUsagePatch(usage, patch)
+ }
+}
+
+type sseUsagePatch struct {
+ inputTokens int
+ hasInputTokens bool
+ outputTokens int
+ hasOutputTokens bool
+ cacheCreationInputTokens int
+ hasCacheCreationInput bool
+ cacheReadInputTokens int
+ hasCacheReadInput bool
+ cacheCreation5mTokens int
+ hasCacheCreation5m bool
+ cacheCreation1hTokens int
+ hasCacheCreation1h bool
+}
+
+func (s *GatewayService) extractSSEUsagePatch(event map[string]any) *sseUsagePatch {
+ if len(event) == 0 {
+ return nil
+ }
+
+ eventType, _ := event["type"].(string)
+ switch eventType {
+ case "message_start":
+ msg, _ := event["message"].(map[string]any)
+ usageObj, _ := msg["usage"].(map[string]any)
+ if len(usageObj) == 0 {
+ return nil
}
- // 解析嵌套的 cache_creation 对象中的 5m/1h 明细
- cc5m := gjson.Get(data, "usage.cache_creation.ephemeral_5m_input_tokens")
- cc1h := gjson.Get(data, "usage.cache_creation.ephemeral_1h_input_tokens")
- if cc5m.Exists() && cc5m.Int() > 0 {
- usage.CacheCreation5mTokens = int(cc5m.Int())
+ patch := &sseUsagePatch{}
+ patch.hasInputTokens = true
+ if v, ok := parseSSEUsageInt(usageObj["input_tokens"]); ok {
+ patch.inputTokens = v
}
- if cc1h.Exists() && cc1h.Int() > 0 {
- usage.CacheCreation1hTokens = int(cc1h.Int())
+ patch.hasCacheCreationInput = true
+ if v, ok := parseSSEUsageInt(usageObj["cache_creation_input_tokens"]); ok {
+ patch.cacheCreationInputTokens = v
+ }
+ patch.hasCacheReadInput = true
+ if v, ok := parseSSEUsageInt(usageObj["cache_read_input_tokens"]); ok {
+ patch.cacheReadInputTokens = v
+ }
+ if cc, ok := usageObj["cache_creation"].(map[string]any); ok {
+ if v, exists := parseSSEUsageInt(cc["ephemeral_5m_input_tokens"]); exists {
+ patch.cacheCreation5mTokens = v
+ patch.hasCacheCreation5m = true
+ }
+ if v, exists := parseSSEUsageInt(cc["ephemeral_1h_input_tokens"]); exists {
+ patch.cacheCreation1hTokens = v
+ patch.hasCacheCreation1h = true
+ }
+ }
+ return patch
+
+ case "message_delta":
+ usageObj, _ := event["usage"].(map[string]any)
+ if len(usageObj) == 0 {
+ return nil
+ }
+
+ patch := &sseUsagePatch{}
+ if v, ok := parseSSEUsageInt(usageObj["input_tokens"]); ok && v > 0 {
+ patch.inputTokens = v
+ patch.hasInputTokens = true
+ }
+ if v, ok := parseSSEUsageInt(usageObj["output_tokens"]); ok && v > 0 {
+ patch.outputTokens = v
+ patch.hasOutputTokens = true
+ }
+ if v, ok := parseSSEUsageInt(usageObj["cache_creation_input_tokens"]); ok && v > 0 {
+ patch.cacheCreationInputTokens = v
+ patch.hasCacheCreationInput = true
+ }
+ if v, ok := parseSSEUsageInt(usageObj["cache_read_input_tokens"]); ok && v > 0 {
+ patch.cacheReadInputTokens = v
+ patch.hasCacheReadInput = true
+ }
+ if cc, ok := usageObj["cache_creation"].(map[string]any); ok {
+ if v, exists := parseSSEUsageInt(cc["ephemeral_5m_input_tokens"]); exists && v > 0 {
+ patch.cacheCreation5mTokens = v
+ patch.hasCacheCreation5m = true
+ }
+ if v, exists := parseSSEUsageInt(cc["ephemeral_1h_input_tokens"]); exists && v > 0 {
+ patch.cacheCreation1hTokens = v
+ patch.hasCacheCreation1h = true
+ }
+ }
+ return patch
+ }
+
+ return nil
+}
+
+func mergeSSEUsagePatch(usage *ClaudeUsage, patch *sseUsagePatch) {
+ if usage == nil || patch == nil {
+ return
+ }
+
+ if patch.hasInputTokens {
+ usage.InputTokens = patch.inputTokens
+ }
+ if patch.hasCacheCreationInput {
+ usage.CacheCreationInputTokens = patch.cacheCreationInputTokens
+ }
+ if patch.hasCacheReadInput {
+ usage.CacheReadInputTokens = patch.cacheReadInputTokens
+ }
+ if patch.hasOutputTokens {
+ usage.OutputTokens = patch.outputTokens
+ }
+ if patch.hasCacheCreation5m {
+ usage.CacheCreation5mTokens = patch.cacheCreation5mTokens
+ }
+ if patch.hasCacheCreation1h {
+ usage.CacheCreation1hTokens = patch.cacheCreation1hTokens
+ }
+}
+
+func parseSSEUsageInt(value any) (int, bool) {
+ switch v := value.(type) {
+ case float64:
+ return int(v), true
+ case float32:
+ return int(v), true
+ case int:
+ return v, true
+ case int64:
+ return int(v), true
+ case int32:
+ return int(v), true
+ case json.Number:
+ if i, err := v.Int64(); err == nil {
+ return int(i), true
+ }
+ if f, err := v.Float64(); err == nil {
+ return int(f), true
+ }
+ case string:
+ if parsed, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
+ return parsed, true
}
}
+ return 0, false
}
// applyCacheTTLOverride 将所有 cache creation tokens 归入指定的 TTL 类型。
@@ -5391,25 +6160,32 @@ func applyCacheTTLOverride(usage *ClaudeUsage, target string) bool {
// rewriteCacheCreationJSON 在 JSON usage 对象中重写 cache_creation 嵌套对象的 TTL 分类。
// usageObj 是 usage JSON 对象(map[string]any)。
-func rewriteCacheCreationJSON(usageObj map[string]any, target string) {
+func rewriteCacheCreationJSON(usageObj map[string]any, target string) bool {
ccObj, ok := usageObj["cache_creation"].(map[string]any)
if !ok {
- return
+ return false
}
- v5m, _ := ccObj["ephemeral_5m_input_tokens"].(float64)
- v1h, _ := ccObj["ephemeral_1h_input_tokens"].(float64)
+ v5m, _ := parseSSEUsageInt(ccObj["ephemeral_5m_input_tokens"])
+ v1h, _ := parseSSEUsageInt(ccObj["ephemeral_1h_input_tokens"])
total := v5m + v1h
if total == 0 {
- return
+ return false
}
switch target {
case "1h":
- ccObj["ephemeral_1h_input_tokens"] = total
+ if v1h == total {
+ return false
+ }
+ ccObj["ephemeral_1h_input_tokens"] = float64(total)
ccObj["ephemeral_5m_input_tokens"] = float64(0)
default: // "5m"
- ccObj["ephemeral_5m_input_tokens"] = total
+ if v5m == total {
+ return false
+ }
+ ccObj["ephemeral_5m_input_tokens"] = float64(total)
ccObj["ephemeral_1h_input_tokens"] = float64(0)
}
+ return true
}
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
@@ -5478,7 +6254,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
- responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
contentType := "application/json"
if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
@@ -5993,9 +6769,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
}
- // Antigravity 账户不支持 count_tokens 转发,直接返回空值
+ // Antigravity 账户不支持 count_tokens,返回 404 让客户端 fallback 到本地估算。
+ // 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。
if account.Platform == PlatformAntigravity {
- c.JSON(http.StatusOK, gin.H{"input_tokens": 0})
+ s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for this platform")
return nil
}
@@ -6199,6 +6976,18 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
+
+ // 中转站不支持 count_tokens 端点时(404),返回 404 让客户端 fallback 到本地估算。
+ // 仅在错误消息明确指向 count_tokens endpoint 不存在时生效,避免误吞其他 404(如错误 base_url)。
+ // 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。
+ if isCountTokensUnsupported404(resp.StatusCode, respBody) {
+ logger.LegacyPrintf("service.gateway",
+ "[count_tokens] Upstream does not support count_tokens (404), returning 404: account=%d name=%s msg=%s",
+ account.ID, account.Name, truncateString(upstreamMsg, 512))
+ s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported by upstream")
+ return nil
+ }
+
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
@@ -6234,7 +7023,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex
return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
}
- writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg)
+ writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
contentType := strings.TrimSpace(resp.Header.Get("Content-Type"))
if contentType == "" {
contentType = "application/json"
@@ -6375,7 +7164,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
incomingBeta := req.Header.Get("anthropic-beta")
requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
- drop := map[string]struct{}{claude.BetaContext1M: {}}
+ drop := droppedBetaSet()
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
} else {
clientBetaHeader := req.Header.Get("anthropic-beta")
@@ -6386,7 +7175,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if !strings.Contains(beta, claude.BetaTokenCounting) {
beta = beta + "," + claude.BetaTokenCounting
}
- req.Header.Set("anthropic-beta", stripBetaToken(beta, claude.BetaContext1M))
+ req.Header.Set("anthropic-beta", stripBetaTokensWithSet(beta, defaultDroppedBetasSet))
}
}
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
diff --git a/backend/internal/service/gateway_service_selection_failure_stats_test.go b/backend/internal/service/gateway_service_selection_failure_stats_test.go
new file mode 100644
index 00000000..743d70bb
--- /dev/null
+++ b/backend/internal/service/gateway_service_selection_failure_stats_test.go
@@ -0,0 +1,141 @@
+package service
+
+import (
+ "context"
+ "strings"
+ "testing"
+ "time"
+)
+
+func TestCollectSelectionFailureStats(t *testing.T) {
+ svc := &GatewayService{}
+ model := "sora2-landscape-10s"
+ resetAt := time.Now().Add(2 * time.Minute).Format(time.RFC3339)
+
+ accounts := []Account{
+ // excluded
+ {
+ ID: 1,
+ Platform: PlatformSora,
+ Status: StatusActive,
+ Schedulable: true,
+ },
+ // unschedulable
+ {
+ ID: 2,
+ Platform: PlatformSora,
+ Status: StatusActive,
+ Schedulable: false,
+ },
+ // platform filtered
+ {
+ ID: 3,
+ Platform: PlatformOpenAI,
+ Status: StatusActive,
+ Schedulable: true,
+ },
+ // model unsupported
+ {
+ ID: 4,
+ Platform: PlatformSora,
+ Status: StatusActive,
+ Schedulable: true,
+ Credentials: map[string]any{
+ "model_mapping": map[string]any{
+ "gpt-image": "gpt-image",
+ },
+ },
+ },
+ // model rate limited
+ {
+ ID: 5,
+ Platform: PlatformSora,
+ Status: StatusActive,
+ Schedulable: true,
+ Extra: map[string]any{
+ "model_rate_limits": map[string]any{
+ model: map[string]any{
+ "rate_limit_reset_at": resetAt,
+ },
+ },
+ },
+ },
+ // eligible
+ {
+ ID: 6,
+ Platform: PlatformSora,
+ Status: StatusActive,
+ Schedulable: true,
+ },
+ }
+
+ excluded := map[int64]struct{}{1: {}}
+ stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformSora, excluded, false)
+
+ if stats.Total != 6 {
+ t.Fatalf("total=%d want=6", stats.Total)
+ }
+ if stats.Excluded != 1 {
+ t.Fatalf("excluded=%d want=1", stats.Excluded)
+ }
+ if stats.Unschedulable != 1 {
+ t.Fatalf("unschedulable=%d want=1", stats.Unschedulable)
+ }
+ if stats.PlatformFiltered != 1 {
+ t.Fatalf("platform_filtered=%d want=1", stats.PlatformFiltered)
+ }
+ if stats.ModelUnsupported != 1 {
+ t.Fatalf("model_unsupported=%d want=1", stats.ModelUnsupported)
+ }
+ if stats.ModelRateLimited != 1 {
+ t.Fatalf("model_rate_limited=%d want=1", stats.ModelRateLimited)
+ }
+ if stats.Eligible != 1 {
+ t.Fatalf("eligible=%d want=1", stats.Eligible)
+ }
+}
+
+func TestDiagnoseSelectionFailure_SoraUnschedulableDetail(t *testing.T) {
+ svc := &GatewayService{}
+ acc := &Account{
+ ID: 7,
+ Platform: PlatformSora,
+ Status: StatusActive,
+ Schedulable: false,
+ }
+
+ diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false)
+ if diagnosis.Category != "unschedulable" {
+ t.Fatalf("category=%s want=unschedulable", diagnosis.Category)
+ }
+ if diagnosis.Detail != "schedulable=false" {
+ t.Fatalf("detail=%s want=schedulable=false", diagnosis.Detail)
+ }
+}
+
+func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) {
+ svc := &GatewayService{}
+ model := "sora2-landscape-10s"
+ resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339)
+ acc := &Account{
+ ID: 8,
+ Platform: PlatformSora,
+ Status: StatusActive,
+ Schedulable: true,
+ Extra: map[string]any{
+ "model_rate_limits": map[string]any{
+ model: map[string]any{
+ "rate_limit_reset_at": resetAt,
+ },
+ },
+ },
+ }
+
+ diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformSora, map[int64]struct{}{}, false)
+ if diagnosis.Category != "model_rate_limited" {
+ t.Fatalf("category=%s want=model_rate_limited", diagnosis.Category)
+ }
+ if !strings.Contains(diagnosis.Detail, "remaining=") {
+ t.Fatalf("detail=%s want contains remaining=", diagnosis.Detail)
+ }
+}
diff --git a/backend/internal/service/gateway_service_sora_model_support_test.go b/backend/internal/service/gateway_service_sora_model_support_test.go
new file mode 100644
index 00000000..8ee2a960
--- /dev/null
+++ b/backend/internal/service/gateway_service_sora_model_support_test.go
@@ -0,0 +1,79 @@
+package service
+
+import "testing"
+
+func TestGatewayServiceIsModelSupportedByAccount_SoraNoMappingAllowsAll(t *testing.T) {
+ svc := &GatewayService{}
+ account := &Account{
+ Platform: PlatformSora,
+ Credentials: map[string]any{},
+ }
+
+ if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
+ t.Fatalf("expected sora model to be supported when model_mapping is empty")
+ }
+}
+
+func TestGatewayServiceIsModelSupportedByAccount_SoraLegacyNonSoraMappingDoesNotBlock(t *testing.T) {
+ svc := &GatewayService{}
+ account := &Account{
+ Platform: PlatformSora,
+ Credentials: map[string]any{
+ "model_mapping": map[string]any{
+ "gpt-4o": "gpt-4o",
+ },
+ },
+ }
+
+ if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
+ t.Fatalf("expected sora model to be supported when mapping has no sora selectors")
+ }
+}
+
+func TestGatewayServiceIsModelSupportedByAccount_SoraFamilyAlias(t *testing.T) {
+ svc := &GatewayService{}
+ account := &Account{
+ Platform: PlatformSora,
+ Credentials: map[string]any{
+ "model_mapping": map[string]any{
+ "sora2": "sora2",
+ },
+ },
+ }
+
+ if !svc.isModelSupportedByAccount(account, "sora2-landscape-15s") {
+ t.Fatalf("expected family selector sora2 to support sora2-landscape-15s")
+ }
+}
+
+func TestGatewayServiceIsModelSupportedByAccount_SoraUnderlyingModelAlias(t *testing.T) {
+ svc := &GatewayService{}
+ account := &Account{
+ Platform: PlatformSora,
+ Credentials: map[string]any{
+ "model_mapping": map[string]any{
+ "sy_8": "sy_8",
+ },
+ },
+ }
+
+ if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
+ t.Fatalf("expected underlying model selector sy_8 to support sora2-landscape-10s")
+ }
+}
+
+func TestGatewayServiceIsModelSupportedByAccount_SoraExplicitImageSelectorBlocksVideo(t *testing.T) {
+ svc := &GatewayService{}
+ account := &Account{
+ Platform: PlatformSora,
+ Credentials: map[string]any{
+ "model_mapping": map[string]any{
+ "gpt-image": "gpt-image",
+ },
+ },
+ }
+
+ if svc.isModelSupportedByAccount(account, "sora2-landscape-10s") {
+ t.Fatalf("expected video model to be blocked when mapping explicitly only allows gpt-image")
+ }
+}
diff --git a/backend/internal/service/gateway_service_sora_scheduling_test.go b/backend/internal/service/gateway_service_sora_scheduling_test.go
new file mode 100644
index 00000000..5178e68e
--- /dev/null
+++ b/backend/internal/service/gateway_service_sora_scheduling_test.go
@@ -0,0 +1,89 @@
+package service
+
+import (
+ "context"
+ "testing"
+ "time"
+)
+
+func TestGatewayServiceIsAccountSchedulableForSelectionSoraIgnoresGenericWindows(t *testing.T) {
+ svc := &GatewayService{}
+ now := time.Now()
+ past := now.Add(-1 * time.Minute)
+ future := now.Add(5 * time.Minute)
+
+ acc := &Account{
+ Platform: PlatformSora,
+ Status: StatusActive,
+ Schedulable: true,
+ AutoPauseOnExpired: true,
+ ExpiresAt: &past,
+ OverloadUntil: &future,
+ RateLimitResetAt: &future,
+ }
+
+ if !svc.isAccountSchedulableForSelection(acc) {
+ t.Fatalf("expected sora account to ignore generic expiry/overload/rate-limit windows")
+ }
+}
+
+func TestGatewayServiceIsAccountSchedulableForSelectionNonSoraKeepsGenericLogic(t *testing.T) {
+ svc := &GatewayService{}
+ future := time.Now().Add(5 * time.Minute)
+
+ acc := &Account{
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ Schedulable: true,
+ RateLimitResetAt: &future,
+ }
+
+ if svc.isAccountSchedulableForSelection(acc) {
+ t.Fatalf("expected non-sora account to keep generic schedulable checks")
+ }
+}
+
+func TestGatewayServiceIsAccountSchedulableForModelSelectionSoraChecksModelScopeOnly(t *testing.T) {
+ svc := &GatewayService{}
+ model := "sora2-landscape-10s"
+ resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339)
+ globalResetAt := time.Now().Add(2 * time.Minute)
+
+ acc := &Account{
+ Platform: PlatformSora,
+ Status: StatusActive,
+ Schedulable: true,
+ RateLimitResetAt: &globalResetAt,
+ Extra: map[string]any{
+ "model_rate_limits": map[string]any{
+ model: map[string]any{
+ "rate_limit_reset_at": resetAt,
+ },
+ },
+ },
+ }
+
+ if svc.isAccountSchedulableForModelSelection(context.Background(), acc, model) {
+ t.Fatalf("expected sora account to be blocked by model scope rate limit")
+ }
+}
+
+func TestCollectSelectionFailureStatsSoraIgnoresGenericUnschedulableWindows(t *testing.T) {
+ svc := &GatewayService{}
+ future := time.Now().Add(3 * time.Minute)
+
+ accounts := []Account{
+ {
+ ID: 1,
+ Platform: PlatformSora,
+ Status: StatusActive,
+ Schedulable: true,
+ RateLimitResetAt: &future,
+ },
+ }
+
+ stats := svc.collectSelectionFailureStats(context.Background(), accounts, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false)
+ if stats.Unschedulable != 0 || stats.Eligible != 1 {
+ t.Fatalf("unexpected stats: unschedulable=%d eligible=%d", stats.Unschedulable, stats.Eligible)
+ }
+}
diff --git a/backend/internal/service/gateway_waiting_queue_test.go b/backend/internal/service/gateway_waiting_queue_test.go
index 0ed95c87..0c53323e 100644
--- a/backend/internal/service/gateway_waiting_queue_test.go
+++ b/backend/internal/service/gateway_waiting_queue_test.go
@@ -105,12 +105,12 @@ func TestCalculateMaxWait_Scenarios(t *testing.T) {
concurrency int
expected int
}{
- {5, 25}, // 5 + 20
- {10, 30}, // 10 + 20
- {1, 21}, // 1 + 20
- {0, 21}, // min(1) + 20
- {-1, 21}, // min(1) + 20
- {-10, 21}, // min(1) + 20
+ {5, 25}, // 5 + 20
+ {10, 30}, // 10 + 20
+ {1, 21}, // 1 + 20
+ {0, 21}, // min(1) + 20
+ {-1, 21}, // min(1) + 20
+ {-10, 21}, // min(1) + 20
{100, 120}, // 100 + 20
}
for _, tt := range tests {
diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go
index 8670f99a..1c38b6c2 100644
--- a/backend/internal/service/gemini_messages_compat_service.go
+++ b/backend/internal/service/gemini_messages_compat_service.go
@@ -53,6 +53,7 @@ type GeminiMessagesCompatService struct {
httpUpstream HTTPUpstream
antigravityGatewayService *AntigravityGatewayService
cfg *config.Config
+ responseHeaderFilter *responseheaders.CompiledHeaderFilter
}
func NewGeminiMessagesCompatService(
@@ -76,6 +77,7 @@ func NewGeminiMessagesCompatService(
httpUpstream: httpUpstream,
antigravityGatewayService: antigravityGatewayService,
cfg: cfg,
+ responseHeaderFilter: compileResponseHeaderFilter(cfg),
}
}
@@ -229,6 +231,16 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
account *Account,
requestedModel, platform string,
useMixedScheduling bool,
+) bool {
+ return s.isAccountUsableForRequestWithPrecheck(ctx, account, requestedModel, platform, useMixedScheduling, nil)
+}
+
+func (s *GeminiMessagesCompatService) isAccountUsableForRequestWithPrecheck(
+ ctx context.Context,
+ account *Account,
+ requestedModel, platform string,
+ useMixedScheduling bool,
+ precheckResult map[int64]bool,
) bool {
// 检查模型调度能力
// Check model scheduling capability
@@ -250,7 +262,7 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
// 速率限制预检
// Rate limit precheck
- if !s.passesRateLimitPreCheck(ctx, account, requestedModel) {
+ if !s.passesRateLimitPreCheckWithCache(ctx, account, requestedModel, precheckResult) {
return false
}
@@ -272,15 +284,17 @@ func (s *GeminiMessagesCompatService) isAccountValidForPlatform(account *Account
return false
}
-// passesRateLimitPreCheck 执行速率限制预检。
-// 返回 true 表示通过预检或无需预检。
-//
-// passesRateLimitPreCheck performs rate limit precheck.
-// Returns true if passed or precheck not required.
-func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Context, account *Account, requestedModel string) bool {
+func (s *GeminiMessagesCompatService) passesRateLimitPreCheckWithCache(ctx context.Context, account *Account, requestedModel string, precheckResult map[int64]bool) bool {
if s.rateLimitService == nil || requestedModel == "" {
return true
}
+
+ if precheckResult != nil {
+ if ok, exists := precheckResult[account.ID]; exists {
+ return ok
+ }
+ }
+
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
if err != nil {
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
@@ -302,6 +316,7 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
useMixedScheduling bool,
) *Account {
var selected *Account
+ precheckResult := s.buildPreCheckUsageResultMap(ctx, accounts, requestedModel)
for i := range accounts {
acc := &accounts[i]
@@ -312,7 +327,7 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
}
// 检查账号是否可用于当前请求
- if !s.isAccountUsableForRequest(ctx, acc, requestedModel, platform, useMixedScheduling) {
+ if !s.isAccountUsableForRequestWithPrecheck(ctx, acc, requestedModel, platform, useMixedScheduling, precheckResult) {
continue
}
@@ -330,6 +345,23 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
return selected
}
+func (s *GeminiMessagesCompatService) buildPreCheckUsageResultMap(ctx context.Context, accounts []Account, requestedModel string) map[int64]bool {
+ if s.rateLimitService == nil || requestedModel == "" || len(accounts) == 0 {
+ return nil
+ }
+
+ candidates := make([]*Account, 0, len(accounts))
+ for i := range accounts {
+ candidates = append(candidates, &accounts[i])
+ }
+
+ result, err := s.rateLimitService.PreCheckUsageBatch(ctx, candidates, requestedModel)
+ if err != nil {
+ logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheckBatch] failed: %v", err)
+ }
+ return result
+}
+
// isBetterGeminiAccount 判断 candidate 是否比 current 更优。
// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先(OAuth > 非 OAuth),其次是最久未使用的。
//
@@ -2390,7 +2422,7 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
}
}
- responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
@@ -2415,8 +2447,8 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
}
- if s.cfg != nil {
- responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
+ if s.responseHeaderFilter != nil {
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
c.Status(resp.StatusCode)
@@ -2557,7 +2589,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
wwwAuthenticate := resp.Header.Get("Www-Authenticate")
- filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders)
+ filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.responseHeaderFilter)
if wwwAuthenticate != "" {
filteredHeaders.Set("Www-Authenticate", wwwAuthenticate)
}
diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go
index 0b9734f6..e866bdc3 100644
--- a/backend/internal/service/gemini_oauth_service.go
+++ b/backend/internal/service/gemini_oauth_service.go
@@ -54,6 +54,7 @@ type GeminiOAuthService struct {
proxyRepo ProxyRepository
oauthClient GeminiOAuthClient
codeAssist GeminiCliCodeAssistClient
+ driveClient geminicli.DriveClient
cfg *config.Config
}
@@ -66,6 +67,7 @@ func NewGeminiOAuthService(
proxyRepo ProxyRepository,
oauthClient GeminiOAuthClient,
codeAssist GeminiCliCodeAssistClient,
+ driveClient geminicli.DriveClient,
cfg *config.Config,
) *GeminiOAuthService {
return &GeminiOAuthService{
@@ -73,6 +75,7 @@ func NewGeminiOAuthService(
proxyRepo: proxyRepo,
oauthClient: oauthClient,
codeAssist: codeAssist,
+ driveClient: driveClient,
cfg: cfg,
}
}
@@ -362,9 +365,8 @@ func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken
// Use Drive API to infer tier from storage quota (requires drive.readonly scope)
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Calling Drive API for storage quota...")
- driveClient := geminicli.NewDriveClient()
- storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL)
+ storageInfo, err := s.driveClient.GetStorageQuota(ctx, accessToken, proxyURL)
if err != nil {
// Check if it's a 403 (scope not granted)
if strings.Contains(err.Error(), "status 403") {
diff --git a/backend/internal/service/gemini_oauth_service_test.go b/backend/internal/service/gemini_oauth_service_test.go
index c58a5930..397b581d 100644
--- a/backend/internal/service/gemini_oauth_service_test.go
+++ b/backend/internal/service/gemini_oauth_service_test.go
@@ -101,7 +101,7 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
- svc := NewGeminiOAuthService(nil, nil, nil, tt.cfg)
+ svc := NewGeminiOAuthService(nil, nil, nil, nil, tt.cfg)
got, err := svc.GenerateAuthURL(context.Background(), nil, "https://example.com/auth/callback", tt.projectID, tt.oauthType, "")
if tt.wantErrSubstr != "" {
if err == nil {
@@ -487,7 +487,7 @@ func TestIsNonRetryableGeminiOAuthError(t *testing.T) {
func TestGeminiOAuthService_BuildAccountCredentials(t *testing.T) {
t.Parallel()
- svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{})
+ svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
defer svc.Stop()
t.Run("完整字段", func(t *testing.T) {
@@ -687,7 +687,7 @@ func TestGeminiOAuthService_GetOAuthConfig(t *testing.T) {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
- svc := NewGeminiOAuthService(nil, nil, nil, tt.cfg)
+ svc := NewGeminiOAuthService(nil, nil, nil, nil, tt.cfg)
defer svc.Stop()
result := svc.GetOAuthConfig()
@@ -709,7 +709,7 @@ func TestGeminiOAuthService_GetOAuthConfig(t *testing.T) {
func TestGeminiOAuthService_Stop_NoPanic(t *testing.T) {
t.Parallel()
- svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{})
+ svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
// 调用 Stop 不应 panic
svc.Stop()
@@ -806,6 +806,18 @@ func (m *mockGeminiProxyRepo) ListAccountSummariesByProxyID(ctx context.Context,
panic("not impl")
}
+// mockDriveClient implements geminicli.DriveClient for tests.
+type mockDriveClient struct {
+ getStorageQuotaFunc func(ctx context.Context, accessToken, proxyURL string) (*geminicli.DriveStorageInfo, error)
+}
+
+func (m *mockDriveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*geminicli.DriveStorageInfo, error) {
+ if m.getStorageQuotaFunc != nil {
+ return m.getStorageQuotaFunc(ctx, accessToken, proxyURL)
+ }
+ return nil, fmt.Errorf("drive API not available in test")
+}
+
// =====================
// 新增测试:GeminiOAuthService.RefreshToken(含重试逻辑)
// =====================
@@ -825,7 +837,7 @@ func TestGeminiOAuthService_RefreshToken_Success(t *testing.T) {
},
}
- svc := NewGeminiOAuthService(nil, client, nil, &config.Config{})
+ svc := NewGeminiOAuthService(nil, client, nil, nil, &config.Config{})
defer svc.Stop()
info, err := svc.RefreshToken(context.Background(), "code_assist", "old-refresh", "")
@@ -852,7 +864,7 @@ func TestGeminiOAuthService_RefreshToken_NonRetryableError(t *testing.T) {
},
}
- svc := NewGeminiOAuthService(nil, client, nil, &config.Config{})
+ svc := NewGeminiOAuthService(nil, client, nil, nil, &config.Config{})
defer svc.Stop()
_, err := svc.RefreshToken(context.Background(), "code_assist", "revoked-token", "")
@@ -881,7 +893,7 @@ func TestGeminiOAuthService_RefreshToken_RetryableError(t *testing.T) {
},
}
- svc := NewGeminiOAuthService(nil, client, nil, &config.Config{})
+ svc := NewGeminiOAuthService(nil, client, nil, nil, &config.Config{})
defer svc.Stop()
info, err := svc.RefreshToken(context.Background(), "code_assist", "rt", "")
@@ -903,7 +915,7 @@ func TestGeminiOAuthService_RefreshToken_RetryableError(t *testing.T) {
func TestGeminiOAuthService_RefreshAccountToken_NotGeminiOAuth(t *testing.T) {
t.Parallel()
- svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{})
+ svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
defer svc.Stop()
account := &Account{
@@ -923,7 +935,7 @@ func TestGeminiOAuthService_RefreshAccountToken_NotGeminiOAuth(t *testing.T) {
func TestGeminiOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) {
t.Parallel()
- svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{})
+ svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
defer svc.Stop()
account := &Account{
@@ -958,7 +970,7 @@ func TestGeminiOAuthService_RefreshAccountToken_AIStudio(t *testing.T) {
},
}
- svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{})
+ svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{})
defer svc.Stop()
account := &Account{
@@ -997,7 +1009,7 @@ func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_WithProjectID(t *test
},
}
- svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{})
+ svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{})
defer svc.Stop()
account := &Account{
@@ -1042,7 +1054,7 @@ func TestGeminiOAuthService_RefreshAccountToken_DefaultOAuthType(t *testing.T) {
},
}
- svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{})
+ svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{})
defer svc.Stop()
// 无 oauth_type 凭据的旧账号
@@ -1090,7 +1102,7 @@ func TestGeminiOAuthService_RefreshAccountToken_WithProxy(t *testing.T) {
},
}
- svc := NewGeminiOAuthService(proxyRepo, client, nil, &config.Config{})
+ svc := NewGeminiOAuthService(proxyRepo, client, nil, nil, &config.Config{})
defer svc.Stop()
proxyID := int64(5)
@@ -1132,7 +1144,7 @@ func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_NoProjectID_AutoDetec
},
}
- svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, &config.Config{})
+ svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, nil, &config.Config{})
defer svc.Stop()
account := &Account{
@@ -1181,7 +1193,7 @@ func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_NoProjectID_FailsEmpt
},
}
- svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, &config.Config{})
+ svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, nil, &config.Config{})
defer svc.Stop()
account := &Account{
@@ -1214,7 +1226,7 @@ func TestGeminiOAuthService_RefreshAccountToken_GoogleOne_FreshCache(t *testing.
},
}
- svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{})
+ svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{})
defer svc.Stop()
account := &Account{
@@ -1254,7 +1266,7 @@ func TestGeminiOAuthService_RefreshAccountToken_GoogleOne_NoTierID_DefaultsFree(
},
}
- svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{})
+ svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &mockDriveClient{}, &config.Config{})
defer svc.Stop()
account := &Account{
@@ -1308,7 +1320,7 @@ func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_Fallback(t *t
},
}
- svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, cfg)
+ svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, cfg)
defer svc.Stop()
account := &Account{
@@ -1341,7 +1353,7 @@ func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_NoFallback(t
}
// 无自定义 OAuth 客户端,无法 fallback
- svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{})
+ svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{})
defer svc.Stop()
account := &Account{
@@ -1370,7 +1382,7 @@ func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_NoFallback(t
func TestGeminiOAuthService_ExchangeCode_SessionNotFound(t *testing.T) {
t.Parallel()
- svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{})
+ svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
defer svc.Stop()
_, err := svc.ExchangeCode(context.Background(), &GeminiExchangeCodeInput{
@@ -1389,7 +1401,7 @@ func TestGeminiOAuthService_ExchangeCode_SessionNotFound(t *testing.T) {
func TestGeminiOAuthService_ExchangeCode_InvalidState(t *testing.T) {
t.Parallel()
- svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{})
+ svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
defer svc.Stop()
// 手动创建 session(必须设置 CreatedAt,否则会因 TTL 过期被拒绝)
@@ -1416,7 +1428,7 @@ func TestGeminiOAuthService_ExchangeCode_InvalidState(t *testing.T) {
func TestGeminiOAuthService_ExchangeCode_EmptyState(t *testing.T) {
t.Parallel()
- svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{})
+ svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{})
defer svc.Stop()
svc.sessionStore.Set("test-session", &geminicli.OAuthSession{
diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go
index 86ece03f..6990caca 100644
--- a/backend/internal/service/group.go
+++ b/backend/internal/service/group.go
@@ -32,6 +32,9 @@ type Group struct {
SoraVideoPricePerRequest *float64
SoraVideoPricePerRequestHD *float64
+ // Sora 存储配额
+ SoraStorageQuotaBytes int64
+
// Claude Code 客户端限制
ClaudeCodeOnly bool
FallbackGroupID *int64
diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go
index dc59010d..f3130c91 100644
--- a/backend/internal/service/identity_service.go
+++ b/backend/internal/service/identity_service.go
@@ -46,6 +46,7 @@ type Fingerprint struct {
StainlessArch string
StainlessRuntime string
StainlessRuntimeVersion string
+ UpdatedAt int64 `json:",omitempty"` // Unix timestamp,用于判断是否需要续期TTL
}
// IdentityCache defines cache operations for identity service
@@ -78,14 +79,26 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
// 尝试从缓存获取指纹
cached, err := s.cache.GetFingerprint(ctx, accountID)
if err == nil && cached != nil {
+ needWrite := false
+
// 检查客户端的user-agent是否是更新版本
clientUA := headers.Get("User-Agent")
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
- // 更新user-agent
- cached.UserAgent = clientUA
- // 保存更新后的指纹
- _ = s.cache.SetFingerprint(ctx, accountID, cached)
- logger.LegacyPrintf("service.identity", "Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
+ // 版本升级:merge 语义 — 仅更新请求中实际携带的字段,保留缓存值
+ // 避免缺失的头被硬编码默认值覆盖(如新 CLI 版本 + 旧 SDK 默认值的不一致)
+ mergeHeadersIntoFingerprint(cached, headers)
+ needWrite = true
+ logger.LegacyPrintf("service.identity", "Updated fingerprint for account %d: %s (merge update)", accountID, clientUA)
+ } else if time.Since(time.Unix(cached.UpdatedAt, 0)) > 24*time.Hour {
+ // 距上次写入超过24小时,续期TTL
+ needWrite = true
+ }
+
+ if needWrite {
+ cached.UpdatedAt = time.Now().Unix()
+ if err := s.cache.SetFingerprint(ctx, accountID, cached); err != nil {
+ logger.LegacyPrintf("service.identity", "Warning: failed to refresh fingerprint for account %d: %v", accountID, err)
+ }
}
return cached, nil
}
@@ -95,8 +108,9 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
// 生成随机ClientID
fp.ClientID = generateClientID()
+ fp.UpdatedAt = time.Now().Unix()
- // 保存到缓存(永不过期)
+ // 保存到缓存(7天TTL,每24小时自动续期)
if err := s.cache.SetFingerprint(ctx, accountID, fp); err != nil {
logger.LegacyPrintf("service.identity", "Warning: failed to cache fingerprint for account %d: %v", accountID, err)
}
@@ -127,6 +141,31 @@ func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fin
return fp
}
+// mergeHeadersIntoFingerprint 将请求头中实际存在的字段合并到现有指纹中(用于版本升级场景)
+// 关键语义:请求中有的字段 → 用新值覆盖;缺失的头 → 保留缓存中的已有值
+// 与 createFingerprintFromHeaders 的区别:后者用于首次创建,缺失头回退到 defaultFingerprint;
+// 本函数用于升级更新,缺失头保留缓存值,避免将已知的真实值退化为硬编码默认值
+func mergeHeadersIntoFingerprint(fp *Fingerprint, headers http.Header) {
+ // User-Agent:版本升级的触发条件,一定存在
+ if ua := headers.Get("User-Agent"); ua != "" {
+ fp.UserAgent = ua
+ }
+ // X-Stainless-* 头:仅在请求中实际携带时才更新,否则保留缓存值
+ mergeHeader(headers, "X-Stainless-Lang", &fp.StainlessLang)
+ mergeHeader(headers, "X-Stainless-Package-Version", &fp.StainlessPackageVersion)
+ mergeHeader(headers, "X-Stainless-OS", &fp.StainlessOS)
+ mergeHeader(headers, "X-Stainless-Arch", &fp.StainlessArch)
+ mergeHeader(headers, "X-Stainless-Runtime", &fp.StainlessRuntime)
+ mergeHeader(headers, "X-Stainless-Runtime-Version", &fp.StainlessRuntimeVersion)
+}
+
+// mergeHeader 如果请求头中存在该字段则更新目标值,否则保留原值
+func mergeHeader(headers http.Header, key string, target *string) {
+ if v := headers.Get(key); v != "" {
+ *target = v
+ }
+}
+
// getHeaderOrDefault 获取header值,如果不存在则返回默认值
func getHeaderOrDefault(headers http.Header, key, defaultValue string) string {
if v := headers.Get(key); v != "" {
@@ -371,8 +410,25 @@ func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
return major, minor, patch, true
}
+// extractProduct 提取 User-Agent 中 "/" 前的产品名
+// 例如:claude-cli/2.1.22 (external, cli) -> "claude-cli"
+func extractProduct(ua string) string {
+ if idx := strings.Index(ua, "/"); idx > 0 {
+ return strings.ToLower(ua[:idx])
+ }
+ return ""
+}
+
// isNewerVersion 比较版本号,判断newUA是否比cachedUA更新
+// 要求产品名一致(防止浏览器 UA 如 Mozilla/5.0 误判为更新版本)
func isNewerVersion(newUA, cachedUA string) bool {
+ // 校验产品名一致性
+ newProduct := extractProduct(newUA)
+ cachedProduct := extractProduct(cachedUA)
+ if newProduct == "" || cachedProduct == "" || newProduct != cachedProduct {
+ return false
+ }
+
newMajor, newMinor, newPatch, newOk := parseUserAgentVersion(newUA)
cachedMajor, cachedMinor, cachedPatch, cachedOk := parseUserAgentVersion(cachedUA)
diff --git a/backend/internal/service/model_rate_limit.go b/backend/internal/service/model_rate_limit.go
index ff4b5977..c45615cc 100644
--- a/backend/internal/service/model_rate_limit.go
+++ b/backend/internal/service/model_rate_limit.go
@@ -4,8 +4,6 @@ import (
"context"
"strings"
"time"
-
- "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
)
const modelRateLimitsKey = "model_rate_limits"
@@ -73,7 +71,7 @@ func resolveFinalAntigravityModelKey(ctx context.Context, account *Account, requ
return ""
}
// thinking 会影响 Antigravity 最终模型名(例如 claude-sonnet-4-5 -> claude-sonnet-4-5-thinking)
- if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok {
+ if enabled, ok := ThinkingEnabledFromContext(ctx); ok {
modelKey = applyThinkingModelSuffix(modelKey, enabled)
}
return modelKey
diff --git a/backend/internal/service/oauth_service.go b/backend/internal/service/oauth_service.go
index 6f6261d8..0931f9ce 100644
--- a/backend/internal/service/oauth_service.go
+++ b/backend/internal/service/oauth_service.go
@@ -12,7 +12,7 @@ import (
// OpenAIOAuthClient interface for OpenAI OAuth operations
type OpenAIOAuthClient interface {
- ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error)
+ ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error)
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error)
RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error)
}
diff --git a/backend/internal/service/oauth_service_test.go b/backend/internal/service/oauth_service_test.go
index 72de4b8c..78f39dc5 100644
--- a/backend/internal/service/oauth_service_test.go
+++ b/backend/internal/service/oauth_service_test.go
@@ -14,10 +14,10 @@ import (
// --- mock: ClaudeOAuthClient ---
type mockClaudeOAuthClient struct {
- getOrgUUIDFunc func(ctx context.Context, sessionKey, proxyURL string) (string, error)
- getAuthCodeFunc func(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error)
- exchangeCodeFunc func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error)
- refreshTokenFunc func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error)
+ getOrgUUIDFunc func(ctx context.Context, sessionKey, proxyURL string) (string, error)
+ getAuthCodeFunc func(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error)
+ exchangeCodeFunc func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error)
+ refreshTokenFunc func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error)
}
func (m *mockClaudeOAuthClient) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
@@ -437,9 +437,9 @@ func TestOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) {
// 无 refresh_token 的账号
account := &Account{
- ID: 1,
- Platform: PlatformAnthropic,
- Type: AccountTypeOAuth,
+ ID: 1,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "some-token",
},
@@ -460,9 +460,9 @@ func TestOAuthService_RefreshAccountToken_EmptyRefreshToken(t *testing.T) {
defer svc.Stop()
account := &Account{
- ID: 2,
- Platform: PlatformAnthropic,
- Type: AccountTypeOAuth,
+ ID: 2,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "some-token",
"refresh_token": "",
diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go
new file mode 100644
index 00000000..99013ce5
--- /dev/null
+++ b/backend/internal/service/openai_account_scheduler.go
@@ -0,0 +1,909 @@
+package service
+
+import (
+ "container/heap"
+ "context"
+ "errors"
+ "hash/fnv"
+ "math"
+ "sort"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+const (
+ openAIAccountScheduleLayerPreviousResponse = "previous_response_id"
+ openAIAccountScheduleLayerSessionSticky = "session_hash"
+ openAIAccountScheduleLayerLoadBalance = "load_balance"
+)
+
+type OpenAIAccountScheduleRequest struct {
+ GroupID *int64
+ SessionHash string
+ StickyAccountID int64
+ PreviousResponseID string
+ RequestedModel string
+ RequiredTransport OpenAIUpstreamTransport
+ ExcludedIDs map[int64]struct{}
+}
+
+type OpenAIAccountScheduleDecision struct {
+ Layer string
+ StickyPreviousHit bool
+ StickySessionHit bool
+ CandidateCount int
+ TopK int
+ LatencyMs int64
+ LoadSkew float64
+ SelectedAccountID int64
+ SelectedAccountType string
+}
+
+type OpenAIAccountSchedulerMetricsSnapshot struct {
+ SelectTotal int64
+ StickyPreviousHitTotal int64
+ StickySessionHitTotal int64
+ LoadBalanceSelectTotal int64
+ AccountSwitchTotal int64
+ SchedulerLatencyMsTotal int64
+ SchedulerLatencyMsAvg float64
+ StickyHitRatio float64
+ AccountSwitchRate float64
+ LoadSkewAvg float64
+ RuntimeStatsAccountCount int
+}
+
+type OpenAIAccountScheduler interface {
+ Select(ctx context.Context, req OpenAIAccountScheduleRequest) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error)
+ ReportResult(accountID int64, success bool, firstTokenMs *int)
+ ReportSwitch()
+ SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot
+}
+
+type openAIAccountSchedulerMetrics struct {
+ selectTotal atomic.Int64
+ stickyPreviousHitTotal atomic.Int64
+ stickySessionHitTotal atomic.Int64
+ loadBalanceSelectTotal atomic.Int64
+ accountSwitchTotal atomic.Int64
+ latencyMsTotal atomic.Int64
+ loadSkewMilliTotal atomic.Int64
+}
+
+func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) {
+ if m == nil {
+ return
+ }
+ m.selectTotal.Add(1)
+ m.latencyMsTotal.Add(decision.LatencyMs)
+ m.loadSkewMilliTotal.Add(int64(math.Round(decision.LoadSkew * 1000)))
+ if decision.StickyPreviousHit {
+ m.stickyPreviousHitTotal.Add(1)
+ }
+ if decision.StickySessionHit {
+ m.stickySessionHitTotal.Add(1)
+ }
+ if decision.Layer == openAIAccountScheduleLayerLoadBalance {
+ m.loadBalanceSelectTotal.Add(1)
+ }
+}
+
+func (m *openAIAccountSchedulerMetrics) recordSwitch() {
+ if m == nil {
+ return
+ }
+ m.accountSwitchTotal.Add(1)
+}
+
+type openAIAccountRuntimeStats struct {
+ accounts sync.Map
+ accountCount atomic.Int64
+}
+
+type openAIAccountRuntimeStat struct {
+ errorRateEWMABits atomic.Uint64
+ ttftEWMABits atomic.Uint64
+}
+
+func newOpenAIAccountRuntimeStats() *openAIAccountRuntimeStats {
+ return &openAIAccountRuntimeStats{}
+}
+
+func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccountRuntimeStat {
+ if value, ok := s.accounts.Load(accountID); ok {
+ stat, _ := value.(*openAIAccountRuntimeStat)
+ if stat != nil {
+ return stat
+ }
+ }
+
+ stat := &openAIAccountRuntimeStat{}
+ stat.ttftEWMABits.Store(math.Float64bits(math.NaN()))
+ actual, loaded := s.accounts.LoadOrStore(accountID, stat)
+ if !loaded {
+ s.accountCount.Add(1)
+ return stat
+ }
+ existing, _ := actual.(*openAIAccountRuntimeStat)
+ if existing != nil {
+ return existing
+ }
+ return stat
+}
+
+func updateEWMAAtomic(target *atomic.Uint64, sample float64, alpha float64) {
+ for {
+ oldBits := target.Load()
+ oldValue := math.Float64frombits(oldBits)
+ newValue := alpha*sample + (1-alpha)*oldValue
+ if target.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
+ return
+ }
+ }
+}
+
+func (s *openAIAccountRuntimeStats) report(accountID int64, success bool, firstTokenMs *int) {
+ if s == nil || accountID <= 0 {
+ return
+ }
+ const alpha = 0.2
+ stat := s.loadOrCreate(accountID)
+
+ errorSample := 1.0
+ if success {
+ errorSample = 0.0
+ }
+ updateEWMAAtomic(&stat.errorRateEWMABits, errorSample, alpha)
+
+ if firstTokenMs != nil && *firstTokenMs > 0 {
+ ttft := float64(*firstTokenMs)
+ ttftBits := math.Float64bits(ttft)
+ for {
+ oldBits := stat.ttftEWMABits.Load()
+ oldValue := math.Float64frombits(oldBits)
+ if math.IsNaN(oldValue) {
+ if stat.ttftEWMABits.CompareAndSwap(oldBits, ttftBits) {
+ break
+ }
+ continue
+ }
+ newValue := alpha*ttft + (1-alpha)*oldValue
+ if stat.ttftEWMABits.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
+ break
+ }
+ }
+ }
+}
+
+func (s *openAIAccountRuntimeStats) snapshot(accountID int64) (errorRate float64, ttft float64, hasTTFT bool) {
+ if s == nil || accountID <= 0 {
+ return 0, 0, false
+ }
+ value, ok := s.accounts.Load(accountID)
+ if !ok {
+ return 0, 0, false
+ }
+ stat, _ := value.(*openAIAccountRuntimeStat)
+ if stat == nil {
+ return 0, 0, false
+ }
+ errorRate = clamp01(math.Float64frombits(stat.errorRateEWMABits.Load()))
+ ttftValue := math.Float64frombits(stat.ttftEWMABits.Load())
+ if math.IsNaN(ttftValue) {
+ return errorRate, 0, false
+ }
+ return errorRate, ttftValue, true
+}
+
+func (s *openAIAccountRuntimeStats) size() int {
+ if s == nil {
+ return 0
+ }
+ return int(s.accountCount.Load())
+}
+
+type defaultOpenAIAccountScheduler struct {
+ service *OpenAIGatewayService
+ metrics openAIAccountSchedulerMetrics
+ stats *openAIAccountRuntimeStats
+}
+
+func newDefaultOpenAIAccountScheduler(service *OpenAIGatewayService, stats *openAIAccountRuntimeStats) OpenAIAccountScheduler {
+ if stats == nil {
+ stats = newOpenAIAccountRuntimeStats()
+ }
+ return &defaultOpenAIAccountScheduler{
+ service: service,
+ stats: stats,
+ }
+}
+
+func (s *defaultOpenAIAccountScheduler) Select(
+ ctx context.Context,
+ req OpenAIAccountScheduleRequest,
+) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
+ decision := OpenAIAccountScheduleDecision{}
+ start := time.Now()
+ defer func() {
+ decision.LatencyMs = time.Since(start).Milliseconds()
+ s.metrics.recordSelect(decision)
+ }()
+
+ previousResponseID := strings.TrimSpace(req.PreviousResponseID)
+ if previousResponseID != "" {
+ selection, err := s.service.SelectAccountByPreviousResponseID(
+ ctx,
+ req.GroupID,
+ previousResponseID,
+ req.RequestedModel,
+ req.ExcludedIDs,
+ )
+ if err != nil {
+ return nil, decision, err
+ }
+ if selection != nil && selection.Account != nil {
+ if !s.isAccountTransportCompatible(selection.Account, req.RequiredTransport) {
+ selection = nil
+ }
+ }
+ if selection != nil && selection.Account != nil {
+ decision.Layer = openAIAccountScheduleLayerPreviousResponse
+ decision.StickyPreviousHit = true
+ decision.SelectedAccountID = selection.Account.ID
+ decision.SelectedAccountType = selection.Account.Type
+ if req.SessionHash != "" {
+ _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, selection.Account.ID)
+ }
+ return selection, decision, nil
+ }
+ }
+
+ selection, err := s.selectBySessionHash(ctx, req)
+ if err != nil {
+ return nil, decision, err
+ }
+ if selection != nil && selection.Account != nil {
+ decision.Layer = openAIAccountScheduleLayerSessionSticky
+ decision.StickySessionHit = true
+ decision.SelectedAccountID = selection.Account.ID
+ decision.SelectedAccountType = selection.Account.Type
+ return selection, decision, nil
+ }
+
+ selection, candidateCount, topK, loadSkew, err := s.selectByLoadBalance(ctx, req)
+ decision.Layer = openAIAccountScheduleLayerLoadBalance
+ decision.CandidateCount = candidateCount
+ decision.TopK = topK
+ decision.LoadSkew = loadSkew
+ if err != nil {
+ return nil, decision, err
+ }
+ if selection != nil && selection.Account != nil {
+ decision.SelectedAccountID = selection.Account.ID
+ decision.SelectedAccountType = selection.Account.Type
+ }
+ return selection, decision, nil
+}
+
+func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
+ ctx context.Context,
+ req OpenAIAccountScheduleRequest,
+) (*AccountSelectionResult, error) {
+ sessionHash := strings.TrimSpace(req.SessionHash)
+ if sessionHash == "" || s == nil || s.service == nil || s.service.cache == nil {
+ return nil, nil
+ }
+
+ accountID := req.StickyAccountID
+ if accountID <= 0 {
+ var err error
+ accountID, err = s.service.getStickySessionAccountID(ctx, req.GroupID, sessionHash)
+ if err != nil || accountID <= 0 {
+ return nil, nil
+ }
+ }
+ if accountID <= 0 {
+ return nil, nil
+ }
+ if req.ExcludedIDs != nil {
+ if _, excluded := req.ExcludedIDs[accountID]; excluded {
+ return nil, nil
+ }
+ }
+
+ account, err := s.service.getSchedulableAccount(ctx, accountID)
+ if err != nil || account == nil {
+ _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
+ return nil, nil
+ }
+ if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() {
+ _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
+ return nil, nil
+ }
+ if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
+ return nil, nil
+ }
+ if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
+ _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
+ return nil, nil
+ }
+
+ result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
+ if acquireErr == nil && result.Acquired {
+ _ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL())
+ return &AccountSelectionResult{
+ Account: account,
+ Acquired: true,
+ ReleaseFunc: result.ReleaseFunc,
+ }, nil
+ }
+
+ cfg := s.service.schedulingConfig()
+ if s.service.concurrencyService != nil {
+ return &AccountSelectionResult{
+ Account: account,
+ WaitPlan: &AccountWaitPlan{
+ AccountID: accountID,
+ MaxConcurrency: account.Concurrency,
+ Timeout: cfg.StickySessionWaitTimeout,
+ MaxWaiting: cfg.StickySessionMaxWaiting,
+ },
+ }, nil
+ }
+ return nil, nil
+}
+
+type openAIAccountCandidateScore struct {
+ account *Account
+ loadInfo *AccountLoadInfo
+ score float64
+ errorRate float64
+ ttft float64
+ hasTTFT bool
+}
+
+type openAIAccountCandidateHeap []openAIAccountCandidateScore
+
+func (h openAIAccountCandidateHeap) Len() int {
+ return len(h)
+}
+
+func (h openAIAccountCandidateHeap) Less(i, j int) bool {
+ // 最小堆根节点保存“最差”候选,便于 O(log k) 维护 topK。
+ return isOpenAIAccountCandidateBetter(h[j], h[i])
+}
+
+func (h openAIAccountCandidateHeap) Swap(i, j int) {
+ h[i], h[j] = h[j], h[i]
+}
+
+func (h *openAIAccountCandidateHeap) Push(x any) {
+ candidate, ok := x.(openAIAccountCandidateScore)
+ if !ok {
+ panic("openAIAccountCandidateHeap: invalid element type")
+ }
+ *h = append(*h, candidate)
+}
+
+func (h *openAIAccountCandidateHeap) Pop() any {
+ old := *h
+ n := len(old)
+ last := old[n-1]
+ *h = old[:n-1]
+ return last
+}
+
+func isOpenAIAccountCandidateBetter(left openAIAccountCandidateScore, right openAIAccountCandidateScore) bool {
+ if left.score != right.score {
+ return left.score > right.score
+ }
+ if left.account.Priority != right.account.Priority {
+ return left.account.Priority < right.account.Priority
+ }
+ if left.loadInfo.LoadRate != right.loadInfo.LoadRate {
+ return left.loadInfo.LoadRate < right.loadInfo.LoadRate
+ }
+ if left.loadInfo.WaitingCount != right.loadInfo.WaitingCount {
+ return left.loadInfo.WaitingCount < right.loadInfo.WaitingCount
+ }
+ return left.account.ID < right.account.ID
+}
+
+func selectTopKOpenAICandidates(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore {
+ if len(candidates) == 0 {
+ return nil
+ }
+ if topK <= 0 {
+ topK = 1
+ }
+ if topK >= len(candidates) {
+ ranked := append([]openAIAccountCandidateScore(nil), candidates...)
+ sort.Slice(ranked, func(i, j int) bool {
+ return isOpenAIAccountCandidateBetter(ranked[i], ranked[j])
+ })
+ return ranked
+ }
+
+ best := make(openAIAccountCandidateHeap, 0, topK)
+ for _, candidate := range candidates {
+ if len(best) < topK {
+ heap.Push(&best, candidate)
+ continue
+ }
+ if isOpenAIAccountCandidateBetter(candidate, best[0]) {
+ best[0] = candidate
+ heap.Fix(&best, 0)
+ }
+ }
+
+ ranked := make([]openAIAccountCandidateScore, len(best))
+ copy(ranked, best)
+ sort.Slice(ranked, func(i, j int) bool {
+ return isOpenAIAccountCandidateBetter(ranked[i], ranked[j])
+ })
+ return ranked
+}
+
+type openAISelectionRNG struct {
+ state uint64
+}
+
+func newOpenAISelectionRNG(seed uint64) openAISelectionRNG {
+ if seed == 0 {
+ seed = 0x9e3779b97f4a7c15
+ }
+ return openAISelectionRNG{state: seed}
+}
+
+func (r *openAISelectionRNG) nextUint64() uint64 {
+ // xorshift64*
+ x := r.state
+ x ^= x >> 12
+ x ^= x << 25
+ x ^= x >> 27
+ r.state = x
+ return x * 2685821657736338717
+}
+
+func (r *openAISelectionRNG) nextFloat64() float64 {
+ // [0,1)
+ return float64(r.nextUint64()>>11) / (1 << 53)
+}
+
+func deriveOpenAISelectionSeed(req OpenAIAccountScheduleRequest) uint64 {
+ hasher := fnv.New64a()
+ writeValue := func(value string) {
+ trimmed := strings.TrimSpace(value)
+ if trimmed == "" {
+ return
+ }
+ _, _ = hasher.Write([]byte(trimmed))
+ _, _ = hasher.Write([]byte{0})
+ }
+
+ writeValue(req.SessionHash)
+ writeValue(req.PreviousResponseID)
+ writeValue(req.RequestedModel)
+ if req.GroupID != nil {
+ _, _ = hasher.Write([]byte(strconv.FormatInt(*req.GroupID, 10)))
+ }
+
+ seed := hasher.Sum64()
+ // 对“无会话锚点”的纯负载均衡请求引入时间熵,避免固定命中同一账号。
+ if strings.TrimSpace(req.SessionHash) == "" && strings.TrimSpace(req.PreviousResponseID) == "" {
+ seed ^= uint64(time.Now().UnixNano())
+ }
+ if seed == 0 {
+ seed = uint64(time.Now().UnixNano()) ^ 0x9e3779b97f4a7c15
+ }
+ return seed
+}
+
+func buildOpenAIWeightedSelectionOrder(
+ candidates []openAIAccountCandidateScore,
+ req OpenAIAccountScheduleRequest,
+) []openAIAccountCandidateScore {
+ if len(candidates) <= 1 {
+ return append([]openAIAccountCandidateScore(nil), candidates...)
+ }
+
+ pool := append([]openAIAccountCandidateScore(nil), candidates...)
+ weights := make([]float64, len(pool))
+ minScore := pool[0].score
+ for i := 1; i < len(pool); i++ {
+ if pool[i].score < minScore {
+ minScore = pool[i].score
+ }
+ }
+ for i := range pool {
+ // 将 top-K 分值平移到正区间,避免“单一最高分账号”长期垄断。
+ weight := (pool[i].score - minScore) + 1.0
+ if math.IsNaN(weight) || math.IsInf(weight, 0) || weight <= 0 {
+ weight = 1.0
+ }
+ weights[i] = weight
+ }
+
+ order := make([]openAIAccountCandidateScore, 0, len(pool))
+ rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req))
+ for len(pool) > 0 {
+ total := 0.0
+ for _, w := range weights {
+ total += w
+ }
+
+ selectedIdx := 0
+ if total > 0 {
+ r := rng.nextFloat64() * total
+ acc := 0.0
+ for i, w := range weights {
+ acc += w
+ if r <= acc {
+ selectedIdx = i
+ break
+ }
+ }
+ } else {
+ selectedIdx = int(rng.nextUint64() % uint64(len(pool)))
+ }
+
+ order = append(order, pool[selectedIdx])
+ pool = append(pool[:selectedIdx], pool[selectedIdx+1:]...)
+ weights = append(weights[:selectedIdx], weights[selectedIdx+1:]...)
+ }
+ return order
+}
+
+func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
+ ctx context.Context,
+ req OpenAIAccountScheduleRequest,
+) (*AccountSelectionResult, int, int, float64, error) {
+ accounts, err := s.service.listSchedulableAccounts(ctx, req.GroupID)
+ if err != nil {
+ return nil, 0, 0, 0, err
+ }
+ if len(accounts) == 0 {
+ return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
+ }
+
+ filtered := make([]*Account, 0, len(accounts))
+ loadReq := make([]AccountWithConcurrency, 0, len(accounts))
+ for i := range accounts {
+ account := &accounts[i]
+ if req.ExcludedIDs != nil {
+ if _, excluded := req.ExcludedIDs[account.ID]; excluded {
+ continue
+ }
+ }
+ if !account.IsSchedulable() || !account.IsOpenAI() {
+ continue
+ }
+ if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
+ continue
+ }
+ if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
+ continue
+ }
+ filtered = append(filtered, account)
+ loadReq = append(loadReq, AccountWithConcurrency{
+ ID: account.ID,
+ MaxConcurrency: account.Concurrency,
+ })
+ }
+ if len(filtered) == 0 {
+ return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
+ }
+
+ loadMap := map[int64]*AccountLoadInfo{}
+ if s.service.concurrencyService != nil {
+ if batchLoad, loadErr := s.service.concurrencyService.GetAccountsLoadBatch(ctx, loadReq); loadErr == nil {
+ loadMap = batchLoad
+ }
+ }
+
+ minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority
+ maxWaiting := 1
+ loadRateSum := 0.0
+ loadRateSumSquares := 0.0
+ minTTFT, maxTTFT := 0.0, 0.0
+ hasTTFTSample := false
+ candidates := make([]openAIAccountCandidateScore, 0, len(filtered))
+ for _, account := range filtered {
+ loadInfo := loadMap[account.ID]
+ if loadInfo == nil {
+ loadInfo = &AccountLoadInfo{AccountID: account.ID}
+ }
+ if account.Priority < minPriority {
+ minPriority = account.Priority
+ }
+ if account.Priority > maxPriority {
+ maxPriority = account.Priority
+ }
+ if loadInfo.WaitingCount > maxWaiting {
+ maxWaiting = loadInfo.WaitingCount
+ }
+ errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
+ if hasTTFT && ttft > 0 {
+ if !hasTTFTSample {
+ minTTFT, maxTTFT = ttft, ttft
+ hasTTFTSample = true
+ } else {
+ if ttft < minTTFT {
+ minTTFT = ttft
+ }
+ if ttft > maxTTFT {
+ maxTTFT = ttft
+ }
+ }
+ }
+ loadRate := float64(loadInfo.LoadRate)
+ loadRateSum += loadRate
+ loadRateSumSquares += loadRate * loadRate
+ candidates = append(candidates, openAIAccountCandidateScore{
+ account: account,
+ loadInfo: loadInfo,
+ errorRate: errorRate,
+ ttft: ttft,
+ hasTTFT: hasTTFT,
+ })
+ }
+ loadSkew := calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
+
+ weights := s.service.openAIWSSchedulerWeights()
+ for i := range candidates {
+ item := &candidates[i]
+ priorityFactor := 1.0
+ if maxPriority > minPriority {
+ priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
+ }
+ loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
+ queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
+ errorFactor := 1 - clamp01(item.errorRate)
+ ttftFactor := 0.5
+ if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
+ ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
+ }
+
+ item.score = weights.Priority*priorityFactor +
+ weights.Load*loadFactor +
+ weights.Queue*queueFactor +
+ weights.ErrorRate*errorFactor +
+ weights.TTFT*ttftFactor
+ }
+
+ topK := s.service.openAIWSLBTopK()
+ if topK > len(candidates) {
+ topK = len(candidates)
+ }
+ if topK <= 0 {
+ topK = 1
+ }
+ rankedCandidates := selectTopKOpenAICandidates(candidates, topK)
+ selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req)
+
+ for i := 0; i < len(selectionOrder); i++ {
+ candidate := selectionOrder[i]
+ result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency)
+ if acquireErr != nil {
+ return nil, len(candidates), topK, loadSkew, acquireErr
+ }
+ if result != nil && result.Acquired {
+ if req.SessionHash != "" {
+ _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID)
+ }
+ return &AccountSelectionResult{
+ Account: candidate.account,
+ Acquired: true,
+ ReleaseFunc: result.ReleaseFunc,
+ }, len(candidates), topK, loadSkew, nil
+ }
+ }
+
+ cfg := s.service.schedulingConfig()
+ candidate := selectionOrder[0]
+ return &AccountSelectionResult{
+ Account: candidate.account,
+ WaitPlan: &AccountWaitPlan{
+ AccountID: candidate.account.ID,
+ MaxConcurrency: candidate.account.Concurrency,
+ Timeout: cfg.FallbackWaitTimeout,
+ MaxWaiting: cfg.FallbackMaxWaiting,
+ },
+ }, len(candidates), topK, loadSkew, nil
+}
+
+func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool {
+ // HTTP 入站可回退到 HTTP 线路,不需要在账号选择阶段做传输协议强过滤。
+ if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE {
+ return true
+ }
+ if s == nil || s.service == nil || account == nil {
+ return false
+ }
+ return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
+}
+
+func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
+ if s == nil || s.stats == nil {
+ return
+ }
+ s.stats.report(accountID, success, firstTokenMs)
+}
+
+func (s *defaultOpenAIAccountScheduler) ReportSwitch() {
+ if s == nil {
+ return
+ }
+ s.metrics.recordSwitch()
+}
+
+func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot {
+ if s == nil {
+ return OpenAIAccountSchedulerMetricsSnapshot{}
+ }
+
+ selectTotal := s.metrics.selectTotal.Load()
+ prevHit := s.metrics.stickyPreviousHitTotal.Load()
+ sessionHit := s.metrics.stickySessionHitTotal.Load()
+ switchTotal := s.metrics.accountSwitchTotal.Load()
+ latencyTotal := s.metrics.latencyMsTotal.Load()
+ loadSkewTotal := s.metrics.loadSkewMilliTotal.Load()
+
+ snapshot := OpenAIAccountSchedulerMetricsSnapshot{
+ SelectTotal: selectTotal,
+ StickyPreviousHitTotal: prevHit,
+ StickySessionHitTotal: sessionHit,
+ LoadBalanceSelectTotal: s.metrics.loadBalanceSelectTotal.Load(),
+ AccountSwitchTotal: switchTotal,
+ SchedulerLatencyMsTotal: latencyTotal,
+ RuntimeStatsAccountCount: s.stats.size(),
+ }
+ if selectTotal > 0 {
+ snapshot.SchedulerLatencyMsAvg = float64(latencyTotal) / float64(selectTotal)
+ snapshot.StickyHitRatio = float64(prevHit+sessionHit) / float64(selectTotal)
+ snapshot.AccountSwitchRate = float64(switchTotal) / float64(selectTotal)
+ snapshot.LoadSkewAvg = float64(loadSkewTotal) / 1000 / float64(selectTotal)
+ }
+ return snapshot
+}
+
+func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler {
+ if s == nil {
+ return nil
+ }
+ s.openaiSchedulerOnce.Do(func() {
+ if s.openaiAccountStats == nil {
+ s.openaiAccountStats = newOpenAIAccountRuntimeStats()
+ }
+ if s.openaiScheduler == nil {
+ s.openaiScheduler = newDefaultOpenAIAccountScheduler(s, s.openaiAccountStats)
+ }
+ })
+ return s.openaiScheduler
+}
+
+func (s *OpenAIGatewayService) SelectAccountWithScheduler(
+ ctx context.Context,
+ groupID *int64,
+ previousResponseID string,
+ sessionHash string,
+ requestedModel string,
+ excludedIDs map[int64]struct{},
+ requiredTransport OpenAIUpstreamTransport,
+) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
+ decision := OpenAIAccountScheduleDecision{}
+ scheduler := s.getOpenAIAccountScheduler()
+ if scheduler == nil {
+ selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
+ decision.Layer = openAIAccountScheduleLayerLoadBalance
+ return selection, decision, err
+ }
+
+ var stickyAccountID int64
+ if sessionHash != "" && s.cache != nil {
+ if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil && accountID > 0 {
+ stickyAccountID = accountID
+ }
+ }
+
+ return scheduler.Select(ctx, OpenAIAccountScheduleRequest{
+ GroupID: groupID,
+ SessionHash: sessionHash,
+ StickyAccountID: stickyAccountID,
+ PreviousResponseID: previousResponseID,
+ RequestedModel: requestedModel,
+ RequiredTransport: requiredTransport,
+ ExcludedIDs: excludedIDs,
+ })
+}
+
+func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
+ scheduler := s.getOpenAIAccountScheduler()
+ if scheduler == nil {
+ return
+ }
+ scheduler.ReportResult(accountID, success, firstTokenMs)
+}
+
+func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
+ scheduler := s.getOpenAIAccountScheduler()
+ if scheduler == nil {
+ return
+ }
+ scheduler.ReportSwitch()
+}
+
+func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot {
+ scheduler := s.getOpenAIAccountScheduler()
+ if scheduler == nil {
+ return OpenAIAccountSchedulerMetricsSnapshot{}
+ }
+ return scheduler.SnapshotMetrics()
+}
+
+func (s *OpenAIGatewayService) openAIWSSessionStickyTTL() time.Duration {
+ if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 {
+ return time.Duration(s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second
+ }
+ return openaiStickySessionTTL
+}
+
+func (s *OpenAIGatewayService) openAIWSLBTopK() int {
+ if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.LBTopK > 0 {
+ return s.cfg.Gateway.OpenAIWS.LBTopK
+ }
+ return 7
+}
+
+func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedulerScoreWeightsView {
+ if s != nil && s.cfg != nil {
+ return GatewayOpenAIWSSchedulerScoreWeightsView{
+ Priority: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority,
+ Load: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load,
+ Queue: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue,
+ ErrorRate: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate,
+ TTFT: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT,
+ }
+ }
+ return GatewayOpenAIWSSchedulerScoreWeightsView{
+ Priority: 1.0,
+ Load: 1.0,
+ Queue: 0.7,
+ ErrorRate: 0.8,
+ TTFT: 0.5,
+ }
+}
+
+type GatewayOpenAIWSSchedulerScoreWeightsView struct {
+ Priority float64
+ Load float64
+ Queue float64
+ ErrorRate float64
+ TTFT float64
+}
+
+func clamp01(value float64) float64 {
+ switch {
+ case value < 0:
+ return 0
+ case value > 1:
+ return 1
+ default:
+ return value
+ }
+}
+
+func calcLoadSkewByMoments(sum float64, sumSquares float64, count int) float64 {
+ if count <= 1 {
+ return 0
+ }
+ mean := sum / float64(count)
+ variance := sumSquares/float64(count) - mean*mean
+ if variance < 0 {
+ variance = 0
+ }
+ return math.Sqrt(variance)
+}
diff --git a/backend/internal/service/openai_account_scheduler_benchmark_test.go b/backend/internal/service/openai_account_scheduler_benchmark_test.go
new file mode 100644
index 00000000..897be5b0
--- /dev/null
+++ b/backend/internal/service/openai_account_scheduler_benchmark_test.go
@@ -0,0 +1,83 @@
+package service
+
+import (
+ "sort"
+ "testing"
+)
+
+func buildOpenAISchedulerBenchmarkCandidates(size int) []openAIAccountCandidateScore {
+ if size <= 0 {
+ return nil
+ }
+ candidates := make([]openAIAccountCandidateScore, 0, size)
+ for i := 0; i < size; i++ {
+ accountID := int64(10_000 + i)
+ candidates = append(candidates, openAIAccountCandidateScore{
+ account: &Account{
+ ID: accountID,
+ Priority: i % 7,
+ },
+ loadInfo: &AccountLoadInfo{
+ AccountID: accountID,
+ LoadRate: (i * 17) % 100,
+ WaitingCount: (i * 11) % 13,
+ },
+ score: float64((i*29)%1000) / 100,
+ errorRate: float64((i * 5) % 100 / 100),
+ ttft: float64(30 + (i*3)%500),
+ hasTTFT: i%3 != 0,
+ })
+ }
+ return candidates
+}
+
+func selectTopKOpenAICandidatesBySortBenchmark(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore {
+ if len(candidates) == 0 {
+ return nil
+ }
+ if topK <= 0 {
+ topK = 1
+ }
+ ranked := append([]openAIAccountCandidateScore(nil), candidates...)
+ sort.Slice(ranked, func(i, j int) bool {
+ return isOpenAIAccountCandidateBetter(ranked[i], ranked[j])
+ })
+ if topK > len(ranked) {
+ topK = len(ranked)
+ }
+ return ranked[:topK]
+}
+
+func BenchmarkOpenAIAccountSchedulerSelectTopK(b *testing.B) {
+ cases := []struct {
+ name string
+ size int
+ topK int
+ }{
+ {name: "n_16_k_3", size: 16, topK: 3},
+ {name: "n_64_k_3", size: 64, topK: 3},
+ {name: "n_256_k_5", size: 256, topK: 5},
+ }
+
+ for _, tc := range cases {
+ candidates := buildOpenAISchedulerBenchmarkCandidates(tc.size)
+ b.Run(tc.name+"/heap_topk", func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ result := selectTopKOpenAICandidates(candidates, tc.topK)
+ if len(result) == 0 {
+ b.Fatal("unexpected empty result")
+ }
+ }
+ })
+ b.Run(tc.name+"/full_sort", func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ result := selectTopKOpenAICandidatesBySortBenchmark(candidates, tc.topK)
+ if len(result) == 0 {
+ b.Fatal("unexpected empty result")
+ }
+ }
+ })
+ }
+}
diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go
new file mode 100644
index 00000000..7f6f1b66
--- /dev/null
+++ b/backend/internal/service/openai_account_scheduler_test.go
@@ -0,0 +1,841 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "math"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(9)
+ account := Account{
+ ID: 1001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 2,
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_enabled": true,
+ },
+ }
+ cache := &stubGatewayCache{}
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800
+ cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ }
+
+ store := svc.getOpenAIWSStateStore()
+ require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_001", account.ID, time.Hour))
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "resp_prev_001",
+ "session_hash_001",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, account.ID, selection.Account.ID)
+ require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer)
+ require.True(t, decision.StickyPreviousHit)
+ require.Equal(t, account.ID, cache.sessionBindings["openai:session_hash_001"])
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+}
+
+func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(10)
+ account := Account{
+ ID: 2001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ }
+ cache := &stubGatewayCache{
+ sessionBindings: map[string]int64{
+ "openai:session_hash_abc": account.ID,
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ cache: cache,
+ cfg: &config.Config{},
+ concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ }
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "session_hash_abc",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, account.ID, selection.Account.ID)
+ require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
+ require.True(t, decision.StickySessionHit)
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+}
+
+func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsSticky(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(10100)
+ accounts := []Account{
+ {
+ ID: 21001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ },
+ {
+ ID: 21002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 9,
+ },
+ }
+ cache := &stubGatewayCache{
+ sessionBindings: map[string]int64{
+ "openai:session_hash_sticky_busy": 21001,
+ },
+ }
+ cfg := &config.Config{}
+ cfg.Gateway.Scheduling.StickySessionMaxWaiting = 2
+ cfg.Gateway.Scheduling.StickySessionWaitTimeout = 45 * time.Second
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+
+ concurrencyCache := stubConcurrencyCache{
+ acquireResults: map[int64]bool{
+ 21001: false, // sticky 账号已满
+ 21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换)
+ },
+ waitCounts: map[int64]int{
+ 21001: 999,
+ },
+ loadMap: map[int64]*AccountLoadInfo{
+ 21001: {AccountID: 21001, LoadRate: 90, WaitingCount: 9},
+ 21002: {AccountID: 21002, LoadRate: 1, WaitingCount: 0},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "session_hash_sticky_busy",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(21001), selection.Account.ID, "busy sticky account should remain selected")
+ require.False(t, selection.Acquired)
+ require.NotNil(t, selection.WaitPlan)
+ require.Equal(t, int64(21001), selection.WaitPlan.AccountID)
+ require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
+ require.True(t, decision.StickySessionHit)
+}
+
+func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(1010)
+ account := Account{
+ ID: 2101,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Extra: map[string]any{
+ "openai_ws_force_http": true,
+ },
+ }
+ cache := &stubGatewayCache{
+ sessionBindings: map[string]int64{
+ "openai:session_hash_force_http": account.ID,
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ cache: cache,
+ cfg: &config.Config{},
+ concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ }
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "session_hash_force_http",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, account.ID, selection.Account.ID)
+ require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
+ require.True(t, decision.StickySessionHit)
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+}
+
+func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStickyHTTPAccount(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(1011)
+ accounts := []Account{
+ {
+ ID: 2201,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ },
+ {
+ ID: 2202,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 5,
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_enabled": true,
+ },
+ },
+ }
+ cache := &stubGatewayCache{
+ sessionBindings: map[string]int64{
+ "openai:session_hash_ws_only": 2201,
+ },
+ }
+ cfg := newOpenAIWSV2TestConfig()
+
+ // 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0},
+ 2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "session_hash_ws_only",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportResponsesWebsocketV2,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(2202), selection.Account.ID)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+ require.False(t, decision.StickySessionHit)
+ require.Equal(t, 1, decision.CandidateCount)
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+}
+
+func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailableAccount(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(1012)
+ accounts := []Account{
+ {
+ ID: 2301,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: &stubGatewayCache{},
+ cfg: newOpenAIWSV2TestConfig(),
+ concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ }
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportResponsesWebsocketV2,
+ )
+ require.Error(t, err)
+ require.Nil(t, selection)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+ require.Equal(t, 0, decision.CandidateCount)
+}
+
+func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(11)
+ accounts := []Account{
+ {
+ ID: 3001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ },
+ {
+ ID: 3002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ },
+ {
+ ID: 3003,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.LBTopK = 2
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.4
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1
+
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8},
+ 3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1},
+ 3003: {AccountID: 3003, LoadRate: 10, WaitingCount: 0},
+ },
+ acquireResults: map[int64]bool{
+ 3003: false, // top1 失败,必须回退到 top-K 的下一候选
+ 3002: true,
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: &stubGatewayCache{},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(3002), selection.Account.ID)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+ require.Equal(t, 3, decision.CandidateCount)
+ require.Equal(t, 2, decision.TopK)
+ require.Greater(t, decision.LoadSkew, 0.0)
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+}
+
+func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(12)
+ account := Account{
+ ID: 4001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ }
+ cache := &stubGatewayCache{
+ sessionBindings: map[string]int64{
+ "openai:session_hash_metrics": account.ID,
+ },
+ }
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ cache: cache,
+ cfg: &config.Config{},
+ concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ }
+
+ selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120))
+ svc.RecordOpenAIAccountSwitch()
+
+ snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
+ require.GreaterOrEqual(t, snapshot.SelectTotal, int64(1))
+ require.GreaterOrEqual(t, snapshot.StickySessionHitTotal, int64(1))
+ require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
+ require.GreaterOrEqual(t, snapshot.SchedulerLatencyMsAvg, float64(0))
+ require.GreaterOrEqual(t, snapshot.StickyHitRatio, 0.0)
+ require.GreaterOrEqual(t, snapshot.RuntimeStatsAccountCount, 1)
+}
+
+func intPtrForTest(v int) *int {
+ return &v
+}
+
+func TestOpenAIAccountRuntimeStats_ReportAndSnapshot(t *testing.T) {
+ stats := newOpenAIAccountRuntimeStats()
+ stats.report(1001, true, nil)
+ firstTTFT := 100
+ stats.report(1001, false, &firstTTFT)
+ secondTTFT := 200
+ stats.report(1001, false, &secondTTFT)
+
+ errorRate, ttft, hasTTFT := stats.snapshot(1001)
+ require.True(t, hasTTFT)
+ require.InDelta(t, 0.36, errorRate, 1e-9)
+ require.InDelta(t, 120.0, ttft, 1e-9)
+ require.Equal(t, 1, stats.size())
+}
+
+func TestOpenAIAccountRuntimeStats_ReportConcurrent(t *testing.T) {
+ stats := newOpenAIAccountRuntimeStats()
+
+ const (
+ accountCount = 4
+ workers = 16
+ iterations = 800
+ )
+ var wg sync.WaitGroup
+ wg.Add(workers)
+ for worker := 0; worker < workers; worker++ {
+ worker := worker
+ go func() {
+ defer wg.Done()
+ for i := 0; i < iterations; i++ {
+ accountID := int64(i%accountCount + 1)
+ success := (i+worker)%3 != 0
+ ttft := 80 + (i+worker)%40
+ stats.report(accountID, success, &ttft)
+ }
+ }()
+ }
+ wg.Wait()
+
+ require.Equal(t, accountCount, stats.size())
+ for accountID := int64(1); accountID <= accountCount; accountID++ {
+ errorRate, ttft, hasTTFT := stats.snapshot(accountID)
+ require.GreaterOrEqual(t, errorRate, 0.0)
+ require.LessOrEqual(t, errorRate, 1.0)
+ require.True(t, hasTTFT)
+ require.Greater(t, ttft, 0.0)
+ }
+}
+
+func TestSelectTopKOpenAICandidates(t *testing.T) {
+ candidates := []openAIAccountCandidateScore{
+ {
+ account: &Account{ID: 11, Priority: 2},
+ loadInfo: &AccountLoadInfo{LoadRate: 10, WaitingCount: 1},
+ score: 10.0,
+ },
+ {
+ account: &Account{ID: 12, Priority: 1},
+ loadInfo: &AccountLoadInfo{LoadRate: 20, WaitingCount: 1},
+ score: 9.5,
+ },
+ {
+ account: &Account{ID: 13, Priority: 1},
+ loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 0},
+ score: 10.0,
+ },
+ {
+ account: &Account{ID: 14, Priority: 0},
+ loadInfo: &AccountLoadInfo{LoadRate: 40, WaitingCount: 0},
+ score: 8.0,
+ },
+ }
+
+ top2 := selectTopKOpenAICandidates(candidates, 2)
+ require.Len(t, top2, 2)
+ require.Equal(t, int64(13), top2[0].account.ID)
+ require.Equal(t, int64(11), top2[1].account.ID)
+
+ topAll := selectTopKOpenAICandidates(candidates, 8)
+ require.Len(t, topAll, len(candidates))
+ require.Equal(t, int64(13), topAll[0].account.ID)
+ require.Equal(t, int64(11), topAll[1].account.ID)
+ require.Equal(t, int64(12), topAll[2].account.ID)
+ require.Equal(t, int64(14), topAll[3].account.ID)
+}
+
+func TestBuildOpenAIWeightedSelectionOrder_DeterministicBySessionSeed(t *testing.T) {
+ candidates := []openAIAccountCandidateScore{
+ {
+ account: &Account{ID: 101},
+ loadInfo: &AccountLoadInfo{LoadRate: 10, WaitingCount: 0},
+ score: 4.2,
+ },
+ {
+ account: &Account{ID: 102},
+ loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 1},
+ score: 3.5,
+ },
+ {
+ account: &Account{ID: 103},
+ loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 2},
+ score: 2.1,
+ },
+ }
+ req := OpenAIAccountScheduleRequest{
+ GroupID: int64PtrForTest(99),
+ SessionHash: "session_seed_fixed",
+ RequestedModel: "gpt-5.1",
+ }
+
+ first := buildOpenAIWeightedSelectionOrder(candidates, req)
+ second := buildOpenAIWeightedSelectionOrder(candidates, req)
+ require.Len(t, first, len(candidates))
+ require.Len(t, second, len(candidates))
+ for i := range first {
+ require.Equal(t, first[i].account.ID, second[i].account.ID)
+ }
+}
+
+func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesAcrossSessions(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(15)
+ accounts := []Account{
+ {
+ ID: 5101,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 3,
+ Priority: 0,
+ },
+ {
+ ID: 5102,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 3,
+ Priority: 0,
+ },
+ {
+ ID: 5103,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 3,
+ Priority: 0,
+ },
+ }
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.LBTopK = 3
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
+
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1},
+ 5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1},
+ 5103: {AccountID: 5103, LoadRate: 20, WaitingCount: 1},
+ },
+ }
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ selected := make(map[int64]int, len(accounts))
+ for i := 0; i < 60; i++ {
+ sessionHash := fmt.Sprintf("session_hash_lb_%d", i)
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ sessionHash,
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+ selected[selection.Account.ID]++
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ }
+
+ // 多 session 应该能打散到多个账号,避免“恒定单账号命中”。
+ require.GreaterOrEqual(t, len(selected), 2)
+}
+
+func TestDeriveOpenAISelectionSeed_NoAffinityAddsEntropy(t *testing.T) {
+ req := OpenAIAccountScheduleRequest{
+ RequestedModel: "gpt-5.1",
+ }
+ seed1 := deriveOpenAISelectionSeed(req)
+ time.Sleep(1 * time.Millisecond)
+ seed2 := deriveOpenAISelectionSeed(req)
+ require.NotZero(t, seed1)
+ require.NotZero(t, seed2)
+ require.NotEqual(t, seed1, seed2)
+}
+
+func TestBuildOpenAIWeightedSelectionOrder_HandlesInvalidScores(t *testing.T) {
+ candidates := []openAIAccountCandidateScore{
+ {
+ account: &Account{ID: 901},
+ loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0},
+ score: math.NaN(),
+ },
+ {
+ account: &Account{ID: 902},
+ loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0},
+ score: math.Inf(1),
+ },
+ {
+ account: &Account{ID: 903},
+ loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0},
+ score: -1,
+ },
+ }
+ req := OpenAIAccountScheduleRequest{
+ SessionHash: "seed_invalid_scores",
+ }
+
+ order := buildOpenAIWeightedSelectionOrder(candidates, req)
+ require.Len(t, order, len(candidates))
+ seen := map[int64]struct{}{}
+ for _, item := range order {
+ seen[item.account.ID] = struct{}{}
+ }
+ require.Len(t, seen, len(candidates))
+}
+
+func TestOpenAISelectionRNG_SeedZeroStillWorks(t *testing.T) {
+ rng := newOpenAISelectionRNG(0)
+ v1 := rng.nextUint64()
+ v2 := rng.nextUint64()
+ require.NotEqual(t, v1, v2)
+ require.GreaterOrEqual(t, rng.nextFloat64(), 0.0)
+ require.Less(t, rng.nextFloat64(), 1.0)
+}
+
+func TestOpenAIAccountCandidateHeap_PushPopAndInvalidType(t *testing.T) {
+ h := openAIAccountCandidateHeap{}
+ h.Push(openAIAccountCandidateScore{
+ account: &Account{ID: 7001},
+ loadInfo: &AccountLoadInfo{LoadRate: 0, WaitingCount: 0},
+ score: 1.0,
+ })
+ require.Equal(t, 1, h.Len())
+ popped, ok := h.Pop().(openAIAccountCandidateScore)
+ require.True(t, ok)
+ require.Equal(t, int64(7001), popped.account.ID)
+ require.Equal(t, 0, h.Len())
+
+ require.Panics(t, func() {
+ h.Push("bad_element_type")
+ })
+}
+
+func TestClamp01_AllBranches(t *testing.T) {
+ require.Equal(t, 0.0, clamp01(-0.2))
+ require.Equal(t, 1.0, clamp01(1.3))
+ require.Equal(t, 0.5, clamp01(0.5))
+}
+
+func TestCalcLoadSkewByMoments_Branches(t *testing.T) {
+ require.Equal(t, 0.0, calcLoadSkewByMoments(1, 1, 1))
+ // variance < 0 分支:sumSquares/count - mean^2 为负值时应钳制为 0。
+ require.Equal(t, 0.0, calcLoadSkewByMoments(1, 0, 2))
+ require.GreaterOrEqual(t, calcLoadSkewByMoments(6, 20, 3), 0.0)
+}
+
+func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) {
+ schedulerAny := newDefaultOpenAIAccountScheduler(&OpenAIGatewayService{}, nil)
+ scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler)
+ require.True(t, ok)
+
+ ttft := 100
+ scheduler.ReportResult(1001, true, &ttft)
+ scheduler.ReportSwitch()
+ scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{
+ Layer: openAIAccountScheduleLayerLoadBalance,
+ LatencyMs: 8,
+ LoadSkew: 0.5,
+ StickyPreviousHit: true,
+ })
+ scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{
+ Layer: openAIAccountScheduleLayerSessionSticky,
+ LatencyMs: 6,
+ LoadSkew: 0.2,
+ StickySessionHit: true,
+ })
+
+ snapshot := scheduler.SnapshotMetrics()
+ require.Equal(t, int64(2), snapshot.SelectTotal)
+ require.Equal(t, int64(1), snapshot.StickyPreviousHitTotal)
+ require.Equal(t, int64(1), snapshot.StickySessionHitTotal)
+ require.Equal(t, int64(1), snapshot.LoadBalanceSelectTotal)
+ require.Equal(t, int64(1), snapshot.AccountSwitchTotal)
+ require.Greater(t, snapshot.SchedulerLatencyMsAvg, 0.0)
+ require.Greater(t, snapshot.StickyHitRatio, 0.0)
+ require.Greater(t, snapshot.LoadSkewAvg, 0.0)
+}
+
+func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ ttft := 120
+ svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
+ svc.RecordOpenAIAccountSwitch()
+ snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
+ require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
+ require.Equal(t, 7, svc.openAIWSLBTopK())
+ require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL())
+
+ defaultWeights := svc.openAIWSSchedulerWeights()
+ require.Equal(t, 1.0, defaultWeights.Priority)
+ require.Equal(t, 1.0, defaultWeights.Load)
+ require.Equal(t, 0.7, defaultWeights.Queue)
+ require.Equal(t, 0.8, defaultWeights.ErrorRate)
+ require.Equal(t, 0.5, defaultWeights.TTFT)
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.LBTopK = 9
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 180
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.2
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0.3
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.4
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.5
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.6
+ svcWithCfg := &OpenAIGatewayService{cfg: cfg}
+
+ require.Equal(t, 9, svcWithCfg.openAIWSLBTopK())
+ require.Equal(t, 180*time.Second, svcWithCfg.openAIWSSessionStickyTTL())
+ customWeights := svcWithCfg.openAIWSSchedulerWeights()
+ require.Equal(t, 0.2, customWeights.Priority)
+ require.Equal(t, 0.3, customWeights.Load)
+ require.Equal(t, 0.4, customWeights.Queue)
+ require.Equal(t, 0.5, customWeights.ErrorRate)
+ require.Equal(t, 0.6, customWeights.TTFT)
+}
+
+func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *testing.T) {
+ scheduler := &defaultOpenAIAccountScheduler{}
+ require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportAny))
+ require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE))
+ require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2))
+
+ cfg := newOpenAIWSV2TestConfig()
+ scheduler.service = &OpenAIGatewayService{cfg: cfg}
+ account := &Account{
+ ID: 8801,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_enabled": true,
+ },
+ }
+ require.True(t, scheduler.isAccountTransportCompatible(account, OpenAIUpstreamTransportResponsesWebsocketV2))
+}
+
+func int64PtrForTest(v int64) *int64 {
+ return &v
+}
diff --git a/backend/internal/service/openai_client_transport.go b/backend/internal/service/openai_client_transport.go
new file mode 100644
index 00000000..c9cf3246
--- /dev/null
+++ b/backend/internal/service/openai_client_transport.go
@@ -0,0 +1,71 @@
+package service
+
+import (
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+// OpenAIClientTransport 表示客户端入站协议类型。
+type OpenAIClientTransport string
+
+const (
+ OpenAIClientTransportUnknown OpenAIClientTransport = ""
+ OpenAIClientTransportHTTP OpenAIClientTransport = "http"
+ OpenAIClientTransportWS OpenAIClientTransport = "ws"
+)
+
+const openAIClientTransportContextKey = "openai_client_transport"
+
+// SetOpenAIClientTransport 标记当前请求的客户端入站协议。
+func SetOpenAIClientTransport(c *gin.Context, transport OpenAIClientTransport) {
+ if c == nil {
+ return
+ }
+ normalized := normalizeOpenAIClientTransport(transport)
+ if normalized == OpenAIClientTransportUnknown {
+ return
+ }
+ c.Set(openAIClientTransportContextKey, string(normalized))
+}
+
+// GetOpenAIClientTransport 读取当前请求的客户端入站协议。
+func GetOpenAIClientTransport(c *gin.Context) OpenAIClientTransport {
+ if c == nil {
+ return OpenAIClientTransportUnknown
+ }
+ raw, ok := c.Get(openAIClientTransportContextKey)
+ if !ok || raw == nil {
+ return OpenAIClientTransportUnknown
+ }
+
+ switch v := raw.(type) {
+ case OpenAIClientTransport:
+ return normalizeOpenAIClientTransport(v)
+ case string:
+ return normalizeOpenAIClientTransport(OpenAIClientTransport(v))
+ default:
+ return OpenAIClientTransportUnknown
+ }
+}
+
+func normalizeOpenAIClientTransport(transport OpenAIClientTransport) OpenAIClientTransport {
+ switch strings.ToLower(strings.TrimSpace(string(transport))) {
+ case string(OpenAIClientTransportHTTP), "http_sse", "sse":
+ return OpenAIClientTransportHTTP
+ case string(OpenAIClientTransportWS), "websocket":
+ return OpenAIClientTransportWS
+ default:
+ return OpenAIClientTransportUnknown
+ }
+}
+
+func resolveOpenAIWSDecisionByClientTransport(
+ decision OpenAIWSProtocolDecision,
+ clientTransport OpenAIClientTransport,
+) OpenAIWSProtocolDecision {
+ if clientTransport == OpenAIClientTransportHTTP {
+ return openAIWSHTTPDecision("client_protocol_http")
+ }
+ return decision
+}
diff --git a/backend/internal/service/openai_client_transport_test.go b/backend/internal/service/openai_client_transport_test.go
new file mode 100644
index 00000000..ef90e614
--- /dev/null
+++ b/backend/internal/service/openai_client_transport_test.go
@@ -0,0 +1,107 @@
+package service
+
+import (
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestOpenAIClientTransport_SetAndGet(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+
+ require.Equal(t, OpenAIClientTransportUnknown, GetOpenAIClientTransport(c))
+
+ SetOpenAIClientTransport(c, OpenAIClientTransportHTTP)
+ require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c))
+
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
+ require.Equal(t, OpenAIClientTransportWS, GetOpenAIClientTransport(c))
+}
+
+func TestOpenAIClientTransport_GetNormalizesRawContextValue(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ tests := []struct {
+ name string
+ rawValue any
+ want OpenAIClientTransport
+ }{
+ {
+ name: "type_value_ws",
+ rawValue: OpenAIClientTransportWS,
+ want: OpenAIClientTransportWS,
+ },
+ {
+ name: "http_sse_alias",
+ rawValue: "http_sse",
+ want: OpenAIClientTransportHTTP,
+ },
+ {
+ name: "sse_alias",
+ rawValue: "sSe",
+ want: OpenAIClientTransportHTTP,
+ },
+ {
+ name: "websocket_alias",
+ rawValue: "WebSocket",
+ want: OpenAIClientTransportWS,
+ },
+ {
+ name: "invalid_string",
+ rawValue: "tcp",
+ want: OpenAIClientTransportUnknown,
+ },
+ {
+ name: "invalid_type",
+ rawValue: 123,
+ want: OpenAIClientTransportUnknown,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Set(openAIClientTransportContextKey, tt.rawValue)
+ require.Equal(t, tt.want, GetOpenAIClientTransport(c))
+ })
+ }
+}
+
+func TestOpenAIClientTransport_NilAndUnknownInput(t *testing.T) {
+ SetOpenAIClientTransport(nil, OpenAIClientTransportHTTP)
+ require.Equal(t, OpenAIClientTransportUnknown, GetOpenAIClientTransport(nil))
+
+ gin.SetMode(gin.TestMode)
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+
+ SetOpenAIClientTransport(c, OpenAIClientTransportUnknown)
+ _, exists := c.Get(openAIClientTransportContextKey)
+ require.False(t, exists)
+
+ SetOpenAIClientTransport(c, OpenAIClientTransport(" "))
+ _, exists = c.Get(openAIClientTransportContextKey)
+ require.False(t, exists)
+}
+
+func TestResolveOpenAIWSDecisionByClientTransport(t *testing.T) {
+ base := OpenAIWSProtocolDecision{
+ Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
+ Reason: "ws_v2_enabled",
+ }
+
+ httpDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportHTTP)
+ require.Equal(t, OpenAIUpstreamTransportHTTPSSE, httpDecision.Transport)
+ require.Equal(t, "client_protocol_http", httpDecision.Reason)
+
+ wsDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportWS)
+ require.Equal(t, base, wsDecision)
+
+ unknownDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportUnknown)
+ require.Equal(t, base, unknownDecision)
+}
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index f26ce03f..f624d92a 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -10,10 +10,12 @@ import (
"errors"
"fmt"
"io"
+ "math/rand"
"net/http"
"sort"
"strconv"
"strings"
+ "sync"
"sync/atomic"
"time"
@@ -34,35 +36,46 @@ const (
// OpenAI Platform API for API Key accounts (fallback)
openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
openaiStickySessionTTL = time.Hour // 粘性会话TTL
- codexCLIUserAgent = "codex_cli_rs/0.98.0"
+ codexCLIUserAgent = "codex_cli_rs/0.104.0"
// codex_cli_only 拒绝时单个请求头日志长度上限(字符)
codexCLIOnlyHeaderValueMaxBytes = 256
// OpenAIParsedRequestBodyKey 缓存 handler 侧已解析的请求体,避免重复解析。
OpenAIParsedRequestBodyKey = "openai_parsed_request_body"
+ // OpenAI WS Mode 失败后的重连次数上限(不含首次尝试)。
+ // 与 Codex 客户端保持一致:失败后最多重连 5 次。
+ openAIWSReconnectRetryLimit = 5
+ // OpenAI WS Mode 重连退避默认值(可由配置覆盖)。
+ openAIWSRetryBackoffInitialDefault = 120 * time.Millisecond
+ openAIWSRetryBackoffMaxDefault = 2 * time.Second
+ openAIWSRetryJitterRatioDefault = 0.2
)
// OpenAI allowed headers whitelist (for non-passthrough).
var openaiAllowedHeaders = map[string]bool{
- "accept-language": true,
- "content-type": true,
- "conversation_id": true,
- "user-agent": true,
- "originator": true,
- "session_id": true,
+ "accept-language": true,
+ "content-type": true,
+ "conversation_id": true,
+ "user-agent": true,
+ "originator": true,
+ "session_id": true,
+ "x-codex-turn-state": true,
+ "x-codex-turn-metadata": true,
}
// OpenAI passthrough allowed headers whitelist.
// 透传模式下仅放行这些低风险请求头,避免将非标准/环境噪声头传给上游触发风控。
var openaiPassthroughAllowedHeaders = map[string]bool{
- "accept": true,
- "accept-language": true,
- "content-type": true,
- "conversation_id": true,
- "openai-beta": true,
- "user-agent": true,
- "originator": true,
- "session_id": true,
+ "accept": true,
+ "accept-language": true,
+ "content-type": true,
+ "conversation_id": true,
+ "openai-beta": true,
+ "user-agent": true,
+ "originator": true,
+ "session_id": true,
+ "x-codex-turn-state": true,
+ "x-codex-turn-metadata": true,
}
// codex_cli_only 拒绝时记录的请求头白名单(仅用于诊断日志,不参与上游透传)
@@ -196,10 +209,40 @@ type OpenAIForwardResult struct {
// Stored for usage records display; nil means not provided / not applicable.
ReasoningEffort *string
Stream bool
+ OpenAIWSMode bool
Duration time.Duration
FirstTokenMs *int
}
+type OpenAIWSRetryMetricsSnapshot struct {
+ RetryAttemptsTotal int64 `json:"retry_attempts_total"`
+ RetryBackoffMsTotal int64 `json:"retry_backoff_ms_total"`
+ RetryExhaustedTotal int64 `json:"retry_exhausted_total"`
+ NonRetryableFastFallbackTotal int64 `json:"non_retryable_fast_fallback_total"`
+}
+
+type OpenAICompatibilityFallbackMetricsSnapshot struct {
+ SessionHashLegacyReadFallbackTotal int64 `json:"session_hash_legacy_read_fallback_total"`
+ SessionHashLegacyReadFallbackHit int64 `json:"session_hash_legacy_read_fallback_hit"`
+ SessionHashLegacyDualWriteTotal int64 `json:"session_hash_legacy_dual_write_total"`
+ SessionHashLegacyReadHitRate float64 `json:"session_hash_legacy_read_hit_rate"`
+
+ MetadataLegacyFallbackIsMaxTokensOneHaikuTotal int64 `json:"metadata_legacy_fallback_is_max_tokens_one_haiku_total"`
+ MetadataLegacyFallbackThinkingEnabledTotal int64 `json:"metadata_legacy_fallback_thinking_enabled_total"`
+ MetadataLegacyFallbackPrefetchedStickyAccount int64 `json:"metadata_legacy_fallback_prefetched_sticky_account_total"`
+ MetadataLegacyFallbackPrefetchedStickyGroup int64 `json:"metadata_legacy_fallback_prefetched_sticky_group_total"`
+ MetadataLegacyFallbackSingleAccountRetryTotal int64 `json:"metadata_legacy_fallback_single_account_retry_total"`
+ MetadataLegacyFallbackAccountSwitchCountTotal int64 `json:"metadata_legacy_fallback_account_switch_count_total"`
+ MetadataLegacyFallbackTotal int64 `json:"metadata_legacy_fallback_total"`
+}
+
+type openAIWSRetryMetrics struct {
+ retryAttempts atomic.Int64
+ retryBackoffMs atomic.Int64
+ retryExhausted atomic.Int64
+ nonRetryableFastFallback atomic.Int64
+}
+
// OpenAIGatewayService handles OpenAI API gateway operations
type OpenAIGatewayService struct {
accountRepo AccountRepository
@@ -218,6 +261,19 @@ type OpenAIGatewayService struct {
deferredService *DeferredService
openAITokenProvider *OpenAITokenProvider
toolCorrector *CodexToolCorrector
+ openaiWSResolver OpenAIWSProtocolResolver
+
+ openaiWSPoolOnce sync.Once
+ openaiWSStateStoreOnce sync.Once
+ openaiSchedulerOnce sync.Once
+ openaiWSPool *openAIWSConnPool
+ openaiWSStateStore OpenAIWSStateStore
+ openaiScheduler OpenAIAccountScheduler
+ openaiAccountStats *openAIAccountRuntimeStats
+
+ openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
+ openaiWSRetryMetrics openAIWSRetryMetrics
+ responseHeaderFilter *responseheaders.CompiledHeaderFilter
}
// NewOpenAIGatewayService creates a new OpenAIGatewayService
@@ -237,24 +293,61 @@ func NewOpenAIGatewayService(
deferredService *DeferredService,
openAITokenProvider *OpenAITokenProvider,
) *OpenAIGatewayService {
- return &OpenAIGatewayService{
- accountRepo: accountRepo,
- usageLogRepo: usageLogRepo,
- userRepo: userRepo,
- userSubRepo: userSubRepo,
- cache: cache,
- cfg: cfg,
- codexDetector: NewOpenAICodexClientRestrictionDetector(cfg),
- schedulerSnapshot: schedulerSnapshot,
- concurrencyService: concurrencyService,
- billingService: billingService,
- rateLimitService: rateLimitService,
- billingCacheService: billingCacheService,
- httpUpstream: httpUpstream,
- deferredService: deferredService,
- openAITokenProvider: openAITokenProvider,
- toolCorrector: NewCodexToolCorrector(),
+ svc := &OpenAIGatewayService{
+ accountRepo: accountRepo,
+ usageLogRepo: usageLogRepo,
+ userRepo: userRepo,
+ userSubRepo: userSubRepo,
+ cache: cache,
+ cfg: cfg,
+ codexDetector: NewOpenAICodexClientRestrictionDetector(cfg),
+ schedulerSnapshot: schedulerSnapshot,
+ concurrencyService: concurrencyService,
+ billingService: billingService,
+ rateLimitService: rateLimitService,
+ billingCacheService: billingCacheService,
+ httpUpstream: httpUpstream,
+ deferredService: deferredService,
+ openAITokenProvider: openAITokenProvider,
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ responseHeaderFilter: compileResponseHeaderFilter(cfg),
}
+ svc.logOpenAIWSModeBootstrap()
+ return svc
+}
+
+// CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 worker 和空闲连接。
+// 应在应用优雅关闭时调用。
+func (s *OpenAIGatewayService) CloseOpenAIWSPool() {
+ if s != nil && s.openaiWSPool != nil {
+ s.openaiWSPool.Close()
+ }
+}
+
+func (s *OpenAIGatewayService) logOpenAIWSModeBootstrap() {
+ if s == nil || s.cfg == nil {
+ return
+ }
+ wsCfg := s.cfg.Gateway.OpenAIWS
+ logOpenAIWSModeInfo(
+ "bootstrap enabled=%v oauth_enabled=%v apikey_enabled=%v force_http=%v responses_websockets_v2=%v responses_websockets=%v payload_log_sample_rate=%.3f event_flush_batch_size=%d event_flush_interval_ms=%d prewarm_cooldown_ms=%d retry_backoff_initial_ms=%d retry_backoff_max_ms=%d retry_jitter_ratio=%.3f retry_total_budget_ms=%d ws_read_limit_bytes=%d",
+ wsCfg.Enabled,
+ wsCfg.OAuthEnabled,
+ wsCfg.APIKeyEnabled,
+ wsCfg.ForceHTTP,
+ wsCfg.ResponsesWebsocketsV2,
+ wsCfg.ResponsesWebsockets,
+ wsCfg.PayloadLogSampleRate,
+ wsCfg.EventFlushBatchSize,
+ wsCfg.EventFlushIntervalMS,
+ wsCfg.PrewarmCooldownMS,
+ wsCfg.RetryBackoffInitialMS,
+ wsCfg.RetryBackoffMaxMS,
+ wsCfg.RetryJitterRatio,
+ wsCfg.RetryTotalBudgetMS,
+ openAIWSMessageReadLimitBytes,
+ )
}
func (s *OpenAIGatewayService) getCodexClientRestrictionDetector() CodexClientRestrictionDetector {
@@ -268,6 +361,317 @@ func (s *OpenAIGatewayService) getCodexClientRestrictionDetector() CodexClientRe
return NewOpenAICodexClientRestrictionDetector(cfg)
}
+func (s *OpenAIGatewayService) getOpenAIWSProtocolResolver() OpenAIWSProtocolResolver {
+ if s != nil && s.openaiWSResolver != nil {
+ return s.openaiWSResolver
+ }
+ var cfg *config.Config
+ if s != nil {
+ cfg = s.cfg
+ }
+ return NewOpenAIWSProtocolResolver(cfg)
+}
+
+func classifyOpenAIWSReconnectReason(err error) (string, bool) {
+ if err == nil {
+ return "", false
+ }
+ var fallbackErr *openAIWSFallbackError
+ if !errors.As(err, &fallbackErr) || fallbackErr == nil {
+ return "", false
+ }
+ reason := strings.TrimSpace(fallbackErr.Reason)
+ if reason == "" {
+ return "", false
+ }
+
+ baseReason := strings.TrimPrefix(reason, "prewarm_")
+
+ switch baseReason {
+ case "policy_violation",
+ "message_too_big",
+ "upgrade_required",
+ "ws_unsupported",
+ "auth_failed",
+ "previous_response_not_found":
+ return reason, false
+ }
+
+ switch baseReason {
+ case "read_event",
+ "write_request",
+ "write",
+ "acquire_timeout",
+ "acquire_conn",
+ "conn_queue_full",
+ "dial_failed",
+ "upstream_5xx",
+ "event_error",
+ "error_event",
+ "upstream_error_event",
+ "ws_connection_limit_reached",
+ "missing_final_response":
+ return reason, true
+ default:
+ return reason, false
+ }
+}
+
+func resolveOpenAIWSFallbackErrorResponse(err error) (statusCode int, errType string, clientMessage string, upstreamMessage string, ok bool) {
+ if err == nil {
+ return 0, "", "", "", false
+ }
+ var fallbackErr *openAIWSFallbackError
+ if !errors.As(err, &fallbackErr) || fallbackErr == nil {
+ return 0, "", "", "", false
+ }
+
+ reason := strings.TrimSpace(fallbackErr.Reason)
+ reason = strings.TrimPrefix(reason, "prewarm_")
+ if reason == "" {
+ return 0, "", "", "", false
+ }
+
+ var dialErr *openAIWSDialError
+ if fallbackErr.Err != nil && errors.As(fallbackErr.Err, &dialErr) && dialErr != nil {
+ if dialErr.StatusCode > 0 {
+ statusCode = dialErr.StatusCode
+ }
+ if dialErr.Err != nil {
+ upstreamMessage = sanitizeUpstreamErrorMessage(strings.TrimSpace(dialErr.Err.Error()))
+ }
+ }
+
+ switch reason {
+ case "previous_response_not_found":
+ if statusCode == 0 {
+ statusCode = http.StatusBadRequest
+ }
+ errType = "invalid_request_error"
+ if upstreamMessage == "" {
+ upstreamMessage = "previous response not found"
+ }
+ case "upgrade_required":
+ if statusCode == 0 {
+ statusCode = http.StatusUpgradeRequired
+ }
+ case "ws_unsupported":
+ if statusCode == 0 {
+ statusCode = http.StatusBadRequest
+ }
+ case "auth_failed":
+ if statusCode == 0 {
+ statusCode = http.StatusUnauthorized
+ }
+ case "upstream_rate_limited":
+ if statusCode == 0 {
+ statusCode = http.StatusTooManyRequests
+ }
+ default:
+ if statusCode == 0 {
+ return 0, "", "", "", false
+ }
+ }
+
+ if upstreamMessage == "" && fallbackErr.Err != nil {
+ upstreamMessage = sanitizeUpstreamErrorMessage(strings.TrimSpace(fallbackErr.Err.Error()))
+ }
+ if upstreamMessage == "" {
+ switch reason {
+ case "upgrade_required":
+ upstreamMessage = "upstream websocket upgrade required"
+ case "ws_unsupported":
+ upstreamMessage = "upstream websocket not supported"
+ case "auth_failed":
+ upstreamMessage = "upstream authentication failed"
+ case "upstream_rate_limited":
+ upstreamMessage = "upstream rate limit exceeded, please retry later"
+ default:
+ upstreamMessage = "Upstream request failed"
+ }
+ }
+
+ if errType == "" {
+ if statusCode == http.StatusTooManyRequests {
+ errType = "rate_limit_error"
+ } else {
+ errType = "upstream_error"
+ }
+ }
+ clientMessage = upstreamMessage
+ return statusCode, errType, clientMessage, upstreamMessage, true
+}
+
+func (s *OpenAIGatewayService) writeOpenAIWSFallbackErrorResponse(c *gin.Context, account *Account, wsErr error) bool {
+ if c == nil || c.Writer == nil || c.Writer.Written() {
+ return false
+ }
+ statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse(wsErr)
+ if !ok {
+ return false
+ }
+ if strings.TrimSpace(clientMessage) == "" {
+ clientMessage = "Upstream request failed"
+ }
+ if strings.TrimSpace(upstreamMessage) == "" {
+ upstreamMessage = clientMessage
+ }
+
+ setOpsUpstreamError(c, statusCode, upstreamMessage, "")
+ if account != nil {
+ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
+ Platform: account.Platform,
+ AccountID: account.ID,
+ AccountName: account.Name,
+ UpstreamStatusCode: statusCode,
+ Kind: "ws_error",
+ Message: upstreamMessage,
+ })
+ }
+ c.JSON(statusCode, gin.H{
+ "error": gin.H{
+ "type": errType,
+ "message": clientMessage,
+ },
+ })
+ return true
+}
+
+func (s *OpenAIGatewayService) openAIWSRetryBackoff(attempt int) time.Duration {
+ if attempt <= 0 {
+ return 0
+ }
+
+ initial := openAIWSRetryBackoffInitialDefault
+ maxBackoff := openAIWSRetryBackoffMaxDefault
+ jitterRatio := openAIWSRetryJitterRatioDefault
+ if s != nil && s.cfg != nil {
+ wsCfg := s.cfg.Gateway.OpenAIWS
+ if wsCfg.RetryBackoffInitialMS > 0 {
+ initial = time.Duration(wsCfg.RetryBackoffInitialMS) * time.Millisecond
+ }
+ if wsCfg.RetryBackoffMaxMS > 0 {
+ maxBackoff = time.Duration(wsCfg.RetryBackoffMaxMS) * time.Millisecond
+ }
+ if wsCfg.RetryJitterRatio >= 0 {
+ jitterRatio = wsCfg.RetryJitterRatio
+ }
+ }
+ if initial <= 0 {
+ return 0
+ }
+ if maxBackoff <= 0 {
+ maxBackoff = initial
+ }
+ if maxBackoff < initial {
+ maxBackoff = initial
+ }
+ if jitterRatio < 0 {
+ jitterRatio = 0
+ }
+ if jitterRatio > 1 {
+ jitterRatio = 1
+ }
+
+ shift := attempt - 1
+ if shift < 0 {
+ shift = 0
+ }
+ backoff := initial
+ if shift > 0 {
+ backoff = initial * time.Duration(1< maxBackoff {
+ backoff = maxBackoff
+ }
+ if jitterRatio <= 0 {
+ return backoff
+ }
+ jitter := time.Duration(float64(backoff) * jitterRatio)
+ if jitter <= 0 {
+ return backoff
+ }
+ delta := time.Duration(rand.Int63n(int64(jitter)*2+1)) - jitter
+ withJitter := backoff + delta
+ if withJitter < 0 {
+ return 0
+ }
+ return withJitter
+}
+
+func (s *OpenAIGatewayService) openAIWSRetryTotalBudget() time.Duration {
+ if s != nil && s.cfg != nil {
+ ms := s.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS
+ if ms <= 0 {
+ return 0
+ }
+ return time.Duration(ms) * time.Millisecond
+ }
+ return 0
+}
+
+func (s *OpenAIGatewayService) recordOpenAIWSRetryAttempt(backoff time.Duration) {
+ if s == nil {
+ return
+ }
+ s.openaiWSRetryMetrics.retryAttempts.Add(1)
+ if backoff > 0 {
+ s.openaiWSRetryMetrics.retryBackoffMs.Add(backoff.Milliseconds())
+ }
+}
+
+func (s *OpenAIGatewayService) recordOpenAIWSRetryExhausted() {
+ if s == nil {
+ return
+ }
+ s.openaiWSRetryMetrics.retryExhausted.Add(1)
+}
+
+func (s *OpenAIGatewayService) recordOpenAIWSNonRetryableFastFallback() {
+ if s == nil {
+ return
+ }
+ s.openaiWSRetryMetrics.nonRetryableFastFallback.Add(1)
+}
+
+func (s *OpenAIGatewayService) SnapshotOpenAIWSRetryMetrics() OpenAIWSRetryMetricsSnapshot {
+ if s == nil {
+ return OpenAIWSRetryMetricsSnapshot{}
+ }
+ return OpenAIWSRetryMetricsSnapshot{
+ RetryAttemptsTotal: s.openaiWSRetryMetrics.retryAttempts.Load(),
+ RetryBackoffMsTotal: s.openaiWSRetryMetrics.retryBackoffMs.Load(),
+ RetryExhaustedTotal: s.openaiWSRetryMetrics.retryExhausted.Load(),
+ NonRetryableFastFallbackTotal: s.openaiWSRetryMetrics.nonRetryableFastFallback.Load(),
+ }
+}
+
+func SnapshotOpenAICompatibilityFallbackMetrics() OpenAICompatibilityFallbackMetricsSnapshot {
+ legacyReadFallbackTotal, legacyReadFallbackHit, legacyDualWriteTotal := openAIStickyCompatStats()
+ isMaxTokensOneHaiku, thinkingEnabled, prefetchedStickyAccount, prefetchedStickyGroup, singleAccountRetry, accountSwitchCount := RequestMetadataFallbackStats()
+
+ readHitRate := float64(0)
+ if legacyReadFallbackTotal > 0 {
+ readHitRate = float64(legacyReadFallbackHit) / float64(legacyReadFallbackTotal)
+ }
+ metadataFallbackTotal := isMaxTokensOneHaiku + thinkingEnabled + prefetchedStickyAccount + prefetchedStickyGroup + singleAccountRetry + accountSwitchCount
+
+ return OpenAICompatibilityFallbackMetricsSnapshot{
+ SessionHashLegacyReadFallbackTotal: legacyReadFallbackTotal,
+ SessionHashLegacyReadFallbackHit: legacyReadFallbackHit,
+ SessionHashLegacyDualWriteTotal: legacyDualWriteTotal,
+ SessionHashLegacyReadHitRate: readHitRate,
+
+ MetadataLegacyFallbackIsMaxTokensOneHaikuTotal: isMaxTokensOneHaiku,
+ MetadataLegacyFallbackThinkingEnabledTotal: thinkingEnabled,
+ MetadataLegacyFallbackPrefetchedStickyAccount: prefetchedStickyAccount,
+ MetadataLegacyFallbackPrefetchedStickyGroup: prefetchedStickyGroup,
+ MetadataLegacyFallbackSingleAccountRetryTotal: singleAccountRetry,
+ MetadataLegacyFallbackAccountSwitchCountTotal: accountSwitchCount,
+ MetadataLegacyFallbackTotal: metadataFallbackTotal,
+ }
+}
+
func (s *OpenAIGatewayService) detectCodexClientRestriction(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult {
return s.getCodexClientRestrictionDetector().Detect(c, account)
}
@@ -494,8 +898,28 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte)
return ""
}
- hash := sha256.Sum256([]byte(sessionID))
- return hex.EncodeToString(hash[:])
+ currentHash, legacyHash := deriveOpenAISessionHashes(sessionID)
+ attachOpenAILegacySessionHashToGin(c, legacyHash)
+ return currentHash
+}
+
+// GenerateSessionHashWithFallback 先按常规信号生成会话哈希;
+// 当未携带 session_id/conversation_id/prompt_cache_key 时,使用 fallbackSeed 生成稳定哈希。
+// 该方法用于 WS ingress,避免会话信号缺失时发生跨账号漂移。
+func (s *OpenAIGatewayService) GenerateSessionHashWithFallback(c *gin.Context, body []byte, fallbackSeed string) string {
+ sessionHash := s.GenerateSessionHash(c, body)
+ if sessionHash != "" {
+ return sessionHash
+ }
+
+ seed := strings.TrimSpace(fallbackSeed)
+ if seed == "" {
+ return ""
+ }
+
+ currentHash, legacyHash := deriveOpenAISessionHashes(seed)
+ attachOpenAILegacySessionHashToGin(c, legacyHash)
+ return currentHash
}
// BindStickySession sets session -> account binding with standard TTL.
@@ -503,7 +927,11 @@ func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *i
if sessionHash == "" || accountID <= 0 {
return nil
}
- return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, accountID, openaiStickySessionTTL)
+ ttl := openaiStickySessionTTL
+ if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 {
+ ttl = time.Duration(s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second
+ }
+ return s.setStickySessionAccountID(ctx, groupID, sessionHash, accountID, ttl)
}
// SelectAccount selects an OpenAI account with sticky session support
@@ -519,11 +947,13 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
// SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
- cacheKey := "openai:" + sessionHash
+ return s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, 0)
+}
+func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) {
// 1. 尝试粘性会话命中
// Try sticky session hit
- if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs); account != nil {
+ if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil {
return account, nil
}
@@ -548,7 +978,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
// 4. 设置粘性会话绑定
// Set sticky session binding
if sessionHash != "" {
- _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, openaiStickySessionTTL)
+ _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL)
}
return selected, nil
@@ -559,14 +989,18 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
//
// tryStickySessionHit attempts to get account from sticky session.
// Returns account if hit and usable; clears session and returns nil if account is unavailable.
-func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, cacheKey, requestedModel string, excludedIDs map[int64]struct{}) *Account {
+func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) *Account {
if sessionHash == "" {
return nil
}
- accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
- if err != nil || accountID <= 0 {
- return nil
+ accountID := stickyAccountID
+ if accountID <= 0 {
+ var err error
+ accountID, err = s.getStickySessionAccountID(ctx, groupID, sessionHash)
+ if err != nil || accountID <= 0 {
+ return nil
+ }
}
if _, excluded := excludedIDs[accountID]; excluded {
@@ -581,7 +1015,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
// 检查账号是否需要清理粘性会话
// Check if sticky session should be cleared
if shouldClearStickySession(account, requestedModel) {
- _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
+ _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil
}
@@ -596,7 +1030,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
// 刷新会话 TTL 并返回账号
// Refresh session TTL and return account
- _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, openaiStickySessionTTL)
+ _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
return account
}
@@ -682,12 +1116,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
cfg := s.schedulingConfig()
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
- if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash); err == nil {
+ if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil {
stickyAccountID = accountID
}
}
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
- account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
+ account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID)
if err != nil {
return nil, err
}
@@ -742,19 +1176,19 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
// ============ Layer 1: Sticky session ============
if sessionHash != "" {
- accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
- if err == nil && accountID > 0 && !isExcluded(accountID) {
+ accountID := stickyAccountID
+ if accountID > 0 && !isExcluded(accountID) {
account, err := s.getSchedulableAccount(ctx, accountID)
if err == nil {
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
- _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
+ _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
}
if !clearSticky && account.IsSchedulable() && account.IsOpenAI() &&
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
- _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
+ _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
@@ -818,7 +1252,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
- _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, acc.ID, openaiStickySessionTTL)
+ _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, acc.ID, openaiStickySessionTTL)
}
return &AccountSelectionResult{
Account: acc,
@@ -868,7 +1302,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
- _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL)
+ _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, item.account.ID, openaiStickySessionTTL)
}
return &AccountSelectionResult{
Account: item.account,
@@ -1010,6 +1444,37 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
originalModel := reqModel
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
+ wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
+ clientTransport := GetOpenAIClientTransport(c)
+ // 仅允许 WS 入站请求走 WS 上游,避免出现 HTTP -> WS 协议混用。
+ wsDecision = resolveOpenAIWSDecisionByClientTransport(wsDecision, clientTransport)
+ if c != nil {
+ c.Set("openai_ws_transport_decision", string(wsDecision.Transport))
+ c.Set("openai_ws_transport_reason", wsDecision.Reason)
+ }
+ if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocketV2 {
+ logOpenAIWSModeDebug(
+ "selected account_id=%d account_type=%s transport=%s reason=%s model=%s stream=%v",
+ account.ID,
+ account.Type,
+ normalizeOpenAIWSLogValue(string(wsDecision.Transport)),
+ normalizeOpenAIWSLogValue(wsDecision.Reason),
+ reqModel,
+ reqStream,
+ )
+ }
+ // 当前仅支持 WSv2;WSv1 命中时直接返回错误,避免出现“配置可开但行为不确定”。
+ if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocket {
+ if c != nil {
+ c.JSON(http.StatusBadRequest, gin.H{
+ "error": gin.H{
+ "type": "invalid_request_error",
+ "message": "OpenAI WSv1 is temporarily unsupported. Please enable responses_websockets_v2.",
+ },
+ })
+ }
+ return nil, errors.New("openai ws v1 is temporarily unsupported; use ws v2")
+ }
passthroughEnabled := account.IsOpenAIPassthroughEnabled()
if passthroughEnabled {
// 透传分支只需要轻量提取字段,避免热路径全量 Unmarshal。
@@ -1037,12 +1502,61 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// Track if body needs re-serialization
bodyModified := false
+ // 单字段补丁快速路径:只要整个变更集最终可归约为同一路径的 set/delete,就避免全量 Marshal。
+ patchDisabled := false
+ patchHasOp := false
+ patchDelete := false
+ patchPath := ""
+ var patchValue any
+ markPatchSet := func(path string, value any) {
+ if strings.TrimSpace(path) == "" {
+ patchDisabled = true
+ return
+ }
+ if patchDisabled {
+ return
+ }
+ if !patchHasOp {
+ patchHasOp = true
+ patchDelete = false
+ patchPath = path
+ patchValue = value
+ return
+ }
+ if patchDelete || patchPath != path {
+ patchDisabled = true
+ return
+ }
+ patchValue = value
+ }
+ markPatchDelete := func(path string) {
+ if strings.TrimSpace(path) == "" {
+ patchDisabled = true
+ return
+ }
+ if patchDisabled {
+ return
+ }
+ if !patchHasOp {
+ patchHasOp = true
+ patchDelete = true
+ patchPath = path
+ return
+ }
+ if !patchDelete || patchPath != path {
+ patchDisabled = true
+ }
+ }
+ disablePatch := func() {
+ patchDisabled = true
+ }
// 非透传模式下,保持历史行为:非 Codex CLI 请求在 instructions 为空时注入默认指令。
if !isCodexCLI && isInstructionsEmpty(reqBody) {
if instructions := strings.TrimSpace(GetOpenCodeInstructions()); instructions != "" {
reqBody["instructions"] = instructions
bodyModified = true
+ markPatchSet("instructions", instructions)
}
}
@@ -1052,6 +1566,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI)
reqBody["model"] = mappedModel
bodyModified = true
+ markPatchSet("model", mappedModel)
}
// 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。
@@ -1063,6 +1578,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
reqBody["model"] = normalizedModel
mappedModel = normalizedModel
bodyModified = true
+ markPatchSet("model", normalizedModel)
}
}
@@ -1071,6 +1587,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" {
reasoning["effort"] = "none"
bodyModified = true
+ markPatchSet("reasoning.effort", "none")
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name)
}
}
@@ -1079,6 +1596,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI)
if codexResult.Modified {
bodyModified = true
+ disablePatch()
}
if codexResult.NormalizedModel != "" {
mappedModel = codexResult.NormalizedModel
@@ -1098,22 +1616,27 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
if account.Type == AccountTypeAPIKey {
delete(reqBody, "max_output_tokens")
bodyModified = true
+ markPatchDelete("max_output_tokens")
}
case PlatformAnthropic:
// For Anthropic (Claude), convert to max_tokens
delete(reqBody, "max_output_tokens")
+ markPatchDelete("max_output_tokens")
if _, hasMaxTokens := reqBody["max_tokens"]; !hasMaxTokens {
reqBody["max_tokens"] = maxOutputTokens
+ disablePatch()
}
bodyModified = true
case PlatformGemini:
// For Gemini, remove (will be handled by Gemini-specific transform)
delete(reqBody, "max_output_tokens")
bodyModified = true
+ markPatchDelete("max_output_tokens")
default:
// For unknown platforms, remove to be safe
delete(reqBody, "max_output_tokens")
bodyModified = true
+ markPatchDelete("max_output_tokens")
}
}
@@ -1122,24 +1645,51 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
if account.Type == AccountTypeAPIKey || account.Platform != PlatformOpenAI {
delete(reqBody, "max_completion_tokens")
bodyModified = true
+ markPatchDelete("max_completion_tokens")
}
}
// Remove unsupported fields (not supported by upstream OpenAI API)
- for _, unsupportedField := range []string{"prompt_cache_retention", "safety_identifier", "previous_response_id"} {
+ unsupportedFields := []string{"prompt_cache_retention", "safety_identifier"}
+ for _, unsupportedField := range unsupportedFields {
if _, has := reqBody[unsupportedField]; has {
delete(reqBody, unsupportedField)
bodyModified = true
+ markPatchDelete(unsupportedField)
}
}
}
+ // 仅在 WSv2 模式保留 previous_response_id,其他模式(HTTP/WSv1)统一过滤。
+ // 注意:该规则同样适用于 Codex CLI 请求,避免 WSv1 向上游透传不支持字段。
+ if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
+ if _, has := reqBody["previous_response_id"]; has {
+ delete(reqBody, "previous_response_id")
+ bodyModified = true
+ markPatchDelete("previous_response_id")
+ }
+ }
+
// Re-serialize body only if modified
if bodyModified {
- var err error
- body, err = json.Marshal(reqBody)
- if err != nil {
- return nil, fmt.Errorf("serialize request body: %w", err)
+ serializedByPatch := false
+ if !patchDisabled && patchHasOp {
+ var patchErr error
+ if patchDelete {
+ body, patchErr = sjson.DeleteBytes(body, patchPath)
+ } else {
+ body, patchErr = sjson.SetBytes(body, patchPath, patchValue)
+ }
+ if patchErr == nil {
+ serializedByPatch = true
+ }
+ }
+ if !serializedByPatch {
+ var marshalErr error
+ body, marshalErr = json.Marshal(reqBody)
+ if marshalErr != nil {
+ return nil, fmt.Errorf("serialize request body: %w", marshalErr)
+ }
}
}
@@ -1149,6 +1699,184 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
return nil, err
}
+ // Capture upstream request body for ops retry of this attempt.
+ setOpsUpstreamRequestBody(c, body)
+
+ // 命中 WS 时仅走 WebSocket Mode;不再自动回退 HTTP。
+ if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocketV2 {
+ wsReqBody := reqBody
+ if len(reqBody) > 0 {
+ wsReqBody = make(map[string]any, len(reqBody))
+ for k, v := range reqBody {
+ wsReqBody[k] = v
+ }
+ }
+ _, hasPreviousResponseID := wsReqBody["previous_response_id"]
+ logOpenAIWSModeDebug(
+ "forward_start account_id=%d account_type=%s model=%s stream=%v has_previous_response_id=%v",
+ account.ID,
+ account.Type,
+ mappedModel,
+ reqStream,
+ hasPreviousResponseID,
+ )
+ maxAttempts := openAIWSReconnectRetryLimit + 1
+ wsAttempts := 0
+ var wsResult *OpenAIForwardResult
+ var wsErr error
+ wsLastFailureReason := ""
+ wsPrevResponseRecoveryTried := false
+ recoverPrevResponseNotFound := func(attempt int) bool {
+ if wsPrevResponseRecoveryTried {
+ return false
+ }
+ previousResponseID := openAIWSPayloadString(wsReqBody, "previous_response_id")
+ if previousResponseID == "" {
+ logOpenAIWSModeInfo(
+ "reconnect_prev_response_recovery_skip account_id=%d attempt=%d reason=missing_previous_response_id previous_response_id_present=false",
+ account.ID,
+ attempt,
+ )
+ return false
+ }
+ if HasFunctionCallOutput(wsReqBody) {
+ logOpenAIWSModeInfo(
+ "reconnect_prev_response_recovery_skip account_id=%d attempt=%d reason=has_function_call_output previous_response_id_present=true",
+ account.ID,
+ attempt,
+ )
+ return false
+ }
+ delete(wsReqBody, "previous_response_id")
+ wsPrevResponseRecoveryTried = true
+ logOpenAIWSModeInfo(
+ "reconnect_prev_response_recovery account_id=%d attempt=%d action=drop_previous_response_id retry=1 previous_response_id=%s previous_response_id_kind=%s",
+ account.ID,
+ attempt,
+ truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(ClassifyOpenAIPreviousResponseIDKind(previousResponseID)),
+ )
+ return true
+ }
+ retryBudget := s.openAIWSRetryTotalBudget()
+ retryStartedAt := time.Now()
+ wsRetryLoop:
+ for attempt := 1; attempt <= maxAttempts; attempt++ {
+ wsAttempts = attempt
+ wsResult, wsErr = s.forwardOpenAIWSV2(
+ ctx,
+ c,
+ account,
+ wsReqBody,
+ token,
+ wsDecision,
+ isCodexCLI,
+ reqStream,
+ originalModel,
+ mappedModel,
+ startTime,
+ attempt,
+ wsLastFailureReason,
+ )
+ if wsErr == nil {
+ break
+ }
+ if c != nil && c.Writer != nil && c.Writer.Written() {
+ break
+ }
+
+ reason, retryable := classifyOpenAIWSReconnectReason(wsErr)
+ if reason != "" {
+ wsLastFailureReason = reason
+ }
+ // previous_response_not_found 说明续链锚点不可用:
+ // 对非 function_call_output 场景,允许一次“去掉 previous_response_id 后重放”。
+ if reason == "previous_response_not_found" && recoverPrevResponseNotFound(attempt) {
+ continue
+ }
+ if retryable && attempt < maxAttempts {
+ backoff := s.openAIWSRetryBackoff(attempt)
+ if retryBudget > 0 && time.Since(retryStartedAt)+backoff > retryBudget {
+ s.recordOpenAIWSRetryExhausted()
+ logOpenAIWSModeInfo(
+ "reconnect_budget_exhausted account_id=%d attempts=%d max_retries=%d reason=%s elapsed_ms=%d budget_ms=%d",
+ account.ID,
+ attempt,
+ openAIWSReconnectRetryLimit,
+ normalizeOpenAIWSLogValue(reason),
+ time.Since(retryStartedAt).Milliseconds(),
+ retryBudget.Milliseconds(),
+ )
+ break
+ }
+ s.recordOpenAIWSRetryAttempt(backoff)
+ logOpenAIWSModeInfo(
+ "reconnect_retry account_id=%d retry=%d max_retries=%d reason=%s backoff_ms=%d",
+ account.ID,
+ attempt,
+ openAIWSReconnectRetryLimit,
+ normalizeOpenAIWSLogValue(reason),
+ backoff.Milliseconds(),
+ )
+ if backoff > 0 {
+ timer := time.NewTimer(backoff)
+ select {
+ case <-ctx.Done():
+ if !timer.Stop() {
+ <-timer.C
+ }
+ wsErr = wrapOpenAIWSFallback("retry_backoff_canceled", ctx.Err())
+ break wsRetryLoop
+ case <-timer.C:
+ }
+ }
+ continue
+ }
+ if retryable {
+ s.recordOpenAIWSRetryExhausted()
+ logOpenAIWSModeInfo(
+ "reconnect_exhausted account_id=%d attempts=%d max_retries=%d reason=%s",
+ account.ID,
+ attempt,
+ openAIWSReconnectRetryLimit,
+ normalizeOpenAIWSLogValue(reason),
+ )
+ } else if reason != "" {
+ s.recordOpenAIWSNonRetryableFastFallback()
+ logOpenAIWSModeInfo(
+ "reconnect_stop account_id=%d attempt=%d reason=%s",
+ account.ID,
+ attempt,
+ normalizeOpenAIWSLogValue(reason),
+ )
+ }
+ break
+ }
+ if wsErr == nil {
+ firstTokenMs := int64(0)
+ hasFirstTokenMs := wsResult != nil && wsResult.FirstTokenMs != nil
+ if hasFirstTokenMs {
+ firstTokenMs = int64(*wsResult.FirstTokenMs)
+ }
+ requestID := ""
+ if wsResult != nil {
+ requestID = strings.TrimSpace(wsResult.RequestID)
+ }
+ logOpenAIWSModeDebug(
+ "forward_succeeded account_id=%d request_id=%s stream=%v has_first_token_ms=%v first_token_ms=%d ws_attempts=%d",
+ account.ID,
+ requestID,
+ reqStream,
+ hasFirstTokenMs,
+ firstTokenMs,
+ wsAttempts,
+ )
+ return wsResult, nil
+ }
+ s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr)
+ return nil, wsErr
+ }
+
// Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
if err != nil {
@@ -1161,9 +1889,6 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
proxyURL = account.Proxy.URL()
}
- // Capture upstream request body for ops retry of this attempt.
- setOpsUpstreamRequestBody(c, body)
-
// Send request
upstreamStart := time.Now()
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
@@ -1260,6 +1985,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
Model: originalModel,
ReasoningEffort: reasoningEffort,
Stream: reqStream,
+ OpenAIWSMode: false,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
@@ -1413,6 +2139,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
Model: reqModel,
ReasoningEffort: reasoningEffort,
Stream: reqStream,
+ OpenAIWSMode: false,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
@@ -1576,7 +2303,7 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough(
UpstreamResponseBody: upstreamDetail,
})
- writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg)
+ writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
@@ -1643,7 +2370,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
account *Account,
startTime time.Time,
) (*openaiStreamingResultPassthrough, error) {
- writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg)
+ writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
// SSE headers
c.Header("Content-Type", "text/event-stream")
@@ -1678,6 +2405,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
for scanner.Scan() {
line := scanner.Text()
if data, ok := extractOpenAISSEDataLine(line); ok {
+ dataBytes := []byte(data)
trimmedData := strings.TrimSpace(data)
if trimmedData == "[DONE]" {
sawDone = true
@@ -1686,7 +2414,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
- s.parseSSEUsage(data, usage)
+ s.parseSSEUsageBytes(dataBytes, usage)
}
if !clientDisconnected {
@@ -1759,19 +2487,8 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
usage := &OpenAIUsage{}
usageParsed := false
if len(body) > 0 {
- var response struct {
- Usage struct {
- InputTokens int `json:"input_tokens"`
- OutputTokens int `json:"output_tokens"`
- InputTokenDetails struct {
- CachedTokens int `json:"cached_tokens"`
- } `json:"input_tokens_details"`
- } `json:"usage"`
- }
- if json.Unmarshal(body, &response) == nil {
- usage.InputTokens = response.Usage.InputTokens
- usage.OutputTokens = response.Usage.OutputTokens
- usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens
+ if parsedUsage, ok := extractOpenAIUsageFromJSONBytes(body); ok {
+ *usage = parsedUsage
usageParsed = true
}
}
@@ -1780,7 +2497,7 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
usage = s.parseSSEUsageFromBody(string(body))
}
- writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg)
+ writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
@@ -1790,12 +2507,12 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
return usage, nil
}
-func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, cfg *config.Config) {
+func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) {
if dst == nil || src == nil {
return
}
- if cfg != nil {
- responseheaders.WriteFilteredHeaders(dst, src, cfg.Security.ResponseHeaders)
+ if filter != nil {
+ responseheaders.WriteFilteredHeaders(dst, src, filter)
} else {
// 兜底:尽量保留最基础的 content-type
if v := strings.TrimSpace(src.Get("Content-Type")); v != "" {
@@ -2074,8 +2791,8 @@ type openaiStreamingResult struct {
}
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
- if s.cfg != nil {
- responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
+ if s.responseHeaderFilter != nil {
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
}
// Set SSE response headers
@@ -2094,6 +2811,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if !ok {
return nil, errors.New("streaming not supported")
}
+ bufferedWriter := bufio.NewWriterSize(w, 4*1024)
+ flushBuffered := func() error {
+ if err := bufferedWriter.Flush(); err != nil {
+ return err
+ }
+ flusher.Flush()
+ return nil
+ }
usage := &OpenAIUsage{}
var firstTokenMs *int
@@ -2105,38 +2830,6 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
- type scanEvent struct {
- line string
- err error
- }
- // 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理
- events := make(chan scanEvent, 16)
- done := make(chan struct{})
- sendEvent := func(ev scanEvent) bool {
- select {
- case events <- ev:
- return true
- case <-done:
- return false
- }
- }
- var lastReadAt int64
- atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
- go func(scanBuf *sseScannerBuf64K) {
- defer putSSEScannerBuf64K(scanBuf)
- defer close(events)
- for scanner.Scan() {
- atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
- if !sendEvent(scanEvent{line: scanner.Text()}) {
- return
- }
- }
- if err := scanner.Err(); err != nil {
- _ = sendEvent(scanEvent{err: err})
- }
- }(scanBuf)
- defer close(done)
-
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
@@ -2179,94 +2872,178 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return
}
errorEventSent = true
- payload := map[string]any{
- "type": "error",
- "sequence_number": 0,
- "error": map[string]any{
- "type": "upstream_error",
- "message": reason,
- "code": reason,
- },
+ payload := `{"type":"error","sequence_number":0,"error":{"type":"upstream_error","message":` + strconv.Quote(reason) + `,"code":` + strconv.Quote(reason) + `}}`
+ if err := flushBuffered(); err != nil {
+ clientDisconnected = true
+ return
}
- if b, err := json.Marshal(payload); err == nil {
- _, _ = fmt.Fprintf(w, "data: %s\n\n", b)
- flusher.Flush()
+ if _, err := bufferedWriter.WriteString("data: " + payload + "\n\n"); err != nil {
+ clientDisconnected = true
+ return
+ }
+ if err := flushBuffered(); err != nil {
+ clientDisconnected = true
}
}
needModelReplace := originalModel != mappedModel
+ resultWithUsage := func() *openaiStreamingResult {
+ return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}
+ }
+ finalizeStream := func() (*openaiStreamingResult, error) {
+ if !clientDisconnected {
+ if err := flushBuffered(); err != nil {
+ clientDisconnected = true
+ logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
+ }
+ }
+ return resultWithUsage(), nil
+ }
+ handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) {
+ if scanErr == nil {
+ return nil, nil, false
+ }
+ // 客户端断开/取消请求时,上游读取往往会返回 context canceled。
+ // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
+ if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) {
+ logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage")
+ return resultWithUsage(), nil, true
+ }
+ // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
+ if clientDisconnected {
+ logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", scanErr)
+ return resultWithUsage(), nil, true
+ }
+ if errors.Is(scanErr, bufio.ErrTooLong) {
+ logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr)
+ sendErrorEvent("response_too_large")
+ return resultWithUsage(), scanErr, true
+ }
+ sendErrorEvent("stream_read_error")
+ return resultWithUsage(), fmt.Errorf("stream read error: %w", scanErr), true
+ }
+ processSSELine := func(line string, queueDrained bool) {
+ lastDataAt = time.Now()
+
+ // Extract data from SSE line (supports both "data: " and "data:" formats)
+ if data, ok := extractOpenAISSEDataLine(line); ok {
+
+ // Replace model in response if needed.
+ // Fast path: most events do not contain model field values.
+ if needModelReplace && mappedModel != "" && strings.Contains(data, mappedModel) {
+ line = s.replaceModelInSSELine(line, mappedModel, originalModel)
+ }
+
+ dataBytes := []byte(data)
+
+ // Correct Codex tool calls if needed (apply_patch -> edit, etc.)
+ if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
+ dataBytes = correctedData
+ data = string(correctedData)
+ line = "data: " + data
+ }
+
+ // 写入客户端(客户端断开后继续 drain 上游)
+ if !clientDisconnected {
+ shouldFlush := queueDrained
+ if firstTokenMs == nil && data != "" && data != "[DONE]" {
+ // 保证首个 token 事件尽快出站,避免影响 TTFT。
+ shouldFlush = true
+ }
+ if _, err := bufferedWriter.WriteString(line); err != nil {
+ clientDisconnected = true
+ logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
+ } else if _, err := bufferedWriter.WriteString("\n"); err != nil {
+ clientDisconnected = true
+ logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
+ } else if shouldFlush {
+ if err := flushBuffered(); err != nil {
+ clientDisconnected = true
+ logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
+ }
+ }
+ }
+
+ // Record first token time
+ if firstTokenMs == nil && data != "" && data != "[DONE]" {
+ ms := int(time.Since(startTime).Milliseconds())
+ firstTokenMs = &ms
+ }
+ s.parseSSEUsageBytes(dataBytes, usage)
+ return
+ }
+
+ // Forward non-data lines as-is
+ if !clientDisconnected {
+ if _, err := bufferedWriter.WriteString(line); err != nil {
+ clientDisconnected = true
+ logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
+ } else if _, err := bufferedWriter.WriteString("\n"); err != nil {
+ clientDisconnected = true
+ logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
+ } else if queueDrained {
+ if err := flushBuffered(); err != nil {
+ clientDisconnected = true
+ logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing")
+ }
+ }
+ }
+ }
+
+ // 无超时/无 keepalive 的常见路径走同步扫描,减少 goroutine 与 channel 开销。
+ if streamInterval <= 0 && keepaliveInterval <= 0 {
+ defer putSSEScannerBuf64K(scanBuf)
+ for scanner.Scan() {
+ processSSELine(scanner.Text(), true)
+ }
+ if result, err, done := handleScanErr(scanner.Err()); done {
+ return result, err
+ }
+ return finalizeStream()
+ }
+
+ type scanEvent struct {
+ line string
+ err error
+ }
+ // 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理
+ events := make(chan scanEvent, 16)
+ done := make(chan struct{})
+ sendEvent := func(ev scanEvent) bool {
+ select {
+ case events <- ev:
+ return true
+ case <-done:
+ return false
+ }
+ }
+ var lastReadAt int64
+ atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
+ go func(scanBuf *sseScannerBuf64K) {
+ defer putSSEScannerBuf64K(scanBuf)
+ defer close(events)
+ for scanner.Scan() {
+ atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
+ if !sendEvent(scanEvent{line: scanner.Text()}) {
+ return
+ }
+ }
+ if err := scanner.Err(); err != nil {
+ _ = sendEvent(scanEvent{err: err})
+ }
+ }(scanBuf)
+ defer close(done)
for {
select {
case ev, ok := <-events:
if !ok {
- return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
+ return finalizeStream()
}
- if ev.err != nil {
- // 客户端断开/取消请求时,上游读取往往会返回 context canceled。
- // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
- if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
- logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage")
- return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
- }
- // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
- if clientDisconnected {
- logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", ev.err)
- return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
- }
- if errors.Is(ev.err, bufio.ErrTooLong) {
- logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
- sendErrorEvent("response_too_large")
- return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
- }
- sendErrorEvent("stream_read_error")
- return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
- }
-
- line := ev.line
- lastDataAt = time.Now()
-
- // Extract data from SSE line (supports both "data: " and "data:" formats)
- if data, ok := extractOpenAISSEDataLine(line); ok {
-
- // Replace model in response if needed
- if needModelReplace {
- line = s.replaceModelInSSELine(line, mappedModel, originalModel)
- }
-
- // Correct Codex tool calls if needed (apply_patch -> edit, etc.)
- if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
- data = correctedData
- line = "data: " + correctedData
- }
-
- // 写入客户端(客户端断开后继续 drain 上游)
- if !clientDisconnected {
- if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
- clientDisconnected = true
- logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
- } else {
- flusher.Flush()
- }
- }
-
- // Record first token time
- if firstTokenMs == nil && data != "" && data != "[DONE]" {
- ms := int(time.Since(startTime).Milliseconds())
- firstTokenMs = &ms
- }
- s.parseSSEUsage(data, usage)
- } else {
- // Forward non-data lines as-is
- if !clientDisconnected {
- if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
- clientDisconnected = true
- logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
- } else {
- flusher.Flush()
- }
- }
+ if result, err, done := handleScanErr(ev.err); done {
+ return result, err
}
+ processSSELine(ev.line, len(events) == 0)
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
@@ -2275,7 +3052,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
if clientDisconnected {
logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage")
- return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
+ return resultWithUsage(), nil
}
logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
// 处理流超时,可能标记账户为临时不可调度或错误状态
@@ -2283,7 +3060,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
}
sendErrorEvent("stream_timeout")
- return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
+ return resultWithUsage(), fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected {
@@ -2292,12 +3069,15 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
- if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
+ if _, err := bufferedWriter.WriteString(":\n\n"); err != nil {
clientDisconnected = true
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
continue
}
- flusher.Flush()
+ if err := flushBuffered(); err != nil {
+ clientDisconnected = true
+ logger.LegacyPrintf("service.openai_gateway", "Client disconnected during keepalive flush, continuing to drain upstream for billing")
+ }
}
}
@@ -2355,29 +3135,49 @@ func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byt
return body
}
- bodyStr := string(body)
- corrected, changed := s.toolCorrector.CorrectToolCallsInSSEData(bodyStr)
+ corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(body)
if changed {
- return []byte(corrected)
+ return corrected
}
return body
}
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
- if usage == nil || data == "" || data == "[DONE]" {
+ s.parseSSEUsageBytes([]byte(data), usage)
+}
+
+func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsage) {
+ if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
return
}
// 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。
- if !strings.Contains(data, `"response.completed"`) {
+ if len(data) < 80 || !bytes.Contains(data, []byte(`"response.completed"`)) {
return
}
- if gjson.Get(data, "type").String() != "response.completed" {
+ if gjson.GetBytes(data, "type").String() != "response.completed" {
return
}
- usage.InputTokens = int(gjson.Get(data, "response.usage.input_tokens").Int())
- usage.OutputTokens = int(gjson.Get(data, "response.usage.output_tokens").Int())
- usage.CacheReadInputTokens = int(gjson.Get(data, "response.usage.input_tokens_details.cached_tokens").Int())
+ usage.InputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens").Int())
+ usage.OutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens").Int())
+ usage.CacheReadInputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens_details.cached_tokens").Int())
+}
+
+func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
+ if len(body) == 0 || !gjson.ValidBytes(body) {
+ return OpenAIUsage{}, false
+ }
+ values := gjson.GetManyBytes(
+ body,
+ "usage.input_tokens",
+ "usage.output_tokens",
+ "usage.input_tokens_details.cached_tokens",
+ )
+ return OpenAIUsage{
+ InputTokens: int(values[0].Int()),
+ OutputTokens: int(values[1].Int()),
+ CacheReadInputTokens: int(values[2].Int()),
+ }, true
}
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
@@ -2403,32 +3203,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
}
}
- // Parse usage
- var response struct {
- Usage struct {
- InputTokens int `json:"input_tokens"`
- OutputTokens int `json:"output_tokens"`
- InputTokenDetails struct {
- CachedTokens int `json:"cached_tokens"`
- } `json:"input_tokens_details"`
- } `json:"usage"`
- }
- if err := json.Unmarshal(body, &response); err != nil {
- return nil, fmt.Errorf("parse response: %w", err)
- }
-
- usage := &OpenAIUsage{
- InputTokens: response.Usage.InputTokens,
- OutputTokens: response.Usage.OutputTokens,
- CacheReadInputTokens: response.Usage.InputTokenDetails.CachedTokens,
+ usageValue, usageOK := extractOpenAIUsageFromJSONBytes(body)
+ if !usageOK {
+ return nil, fmt.Errorf("parse response: invalid json response")
}
+ usage := &usageValue
// Replace model in response if needed
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
- responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
contentType := "application/json"
if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
@@ -2453,19 +3239,8 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
usage := &OpenAIUsage{}
if ok {
- var response struct {
- Usage struct {
- InputTokens int `json:"input_tokens"`
- OutputTokens int `json:"output_tokens"`
- InputTokenDetails struct {
- CachedTokens int `json:"cached_tokens"`
- } `json:"input_tokens_details"`
- } `json:"usage"`
- }
- if err := json.Unmarshal(finalResponse, &response); err == nil {
- usage.InputTokens = response.Usage.InputTokens
- usage.OutputTokens = response.Usage.OutputTokens
- usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens
+ if parsedUsage, parsed := extractOpenAIUsageFromJSONBytes(finalResponse); parsed {
+ *usage = parsedUsage
}
body = finalResponse
if originalModel != mappedModel {
@@ -2481,7 +3256,7 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
body = []byte(bodyText)
}
- responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
contentType := "application/json; charset=utf-8"
if !ok {
@@ -2505,16 +3280,10 @@ func extractCodexFinalResponse(body string) ([]byte, bool) {
if data == "" || data == "[DONE]" {
continue
}
- var event struct {
- Type string `json:"type"`
- Response json.RawMessage `json:"response"`
- }
- if json.Unmarshal([]byte(data), &event) != nil {
- continue
- }
- if event.Type == "response.done" || event.Type == "response.completed" {
- if len(event.Response) > 0 {
- return event.Response, true
+ eventType := gjson.Get(data, "type").String()
+ if eventType == "response.done" || eventType == "response.completed" {
+ if response := gjson.Get(data, "response"); response.Exists() && response.Type == gjson.JSON && response.Raw != "" {
+ return []byte(response.Raw), true
}
}
}
@@ -2532,7 +3301,7 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
if data == "" || data == "[DONE]" {
continue
}
- s.parseSSEUsage(data, usage)
+ s.parseSSEUsageBytes([]byte(data), usage)
}
return usage
}
@@ -2671,6 +3440,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
Stream: result.Stream,
+ OpenAIWSMode: result.OpenAIWSMode,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(),
@@ -3047,6 +3817,9 @@ func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error
if err := json.Unmarshal(body, &reqBody); err != nil {
return nil, fmt.Errorf("parse request: %w", err)
}
+ if c != nil {
+ c.Set(OpenAIParsedRequestBodyKey, reqBody)
+ }
return reqBody, nil
}
diff --git a/backend/internal/service/openai_gateway_service_hotpath_test.go b/backend/internal/service/openai_gateway_service_hotpath_test.go
index 6b11831f..f73c06c5 100644
--- a/backend/internal/service/openai_gateway_service_hotpath_test.go
+++ b/backend/internal/service/openai_gateway_service_hotpath_test.go
@@ -123,3 +123,19 @@ func TestGetOpenAIRequestBodyMap_ParseErrorWithoutCache(t *testing.T) {
require.Error(t, err)
require.Contains(t, err.Error(), "parse request")
}
+
+func TestGetOpenAIRequestBodyMap_WriteBackContextCache(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+
+ got, err := getOpenAIRequestBodyMap(c, []byte(`{"model":"gpt-5","stream":true}`))
+ require.NoError(t, err)
+ require.Equal(t, "gpt-5", got["model"])
+
+ cached, ok := c.Get(OpenAIParsedRequestBodyKey)
+ require.True(t, ok)
+ cachedMap, ok := cached.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, got, cachedMap)
+}
diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go
index 226648e4..89443b69 100644
--- a/backend/internal/service/openai_gateway_service_test.go
+++ b/backend/internal/service/openai_gateway_service_test.go
@@ -5,6 +5,7 @@ import (
"bytes"
"context"
"errors"
+ "fmt"
"io"
"net/http"
"net/http/httptest"
@@ -13,6 +14,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/cespare/xxhash/v2"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -166,6 +168,54 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
}
}
+func TestOpenAIGatewayService_GenerateSessionHash_UsesXXHash64(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+
+ c.Request.Header.Set("session_id", "sess-fixed-value")
+ svc := &OpenAIGatewayService{}
+
+ got := svc.GenerateSessionHash(c, nil)
+ want := fmt.Sprintf("%016x", xxhash.Sum64String("sess-fixed-value"))
+ require.Equal(t, want, got)
+}
+
+func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+
+ c.Request.Header.Set("session_id", "sess-legacy-check")
+ svc := &OpenAIGatewayService{}
+
+ sessionHash := svc.GenerateSessionHash(c, nil)
+ require.NotEmpty(t, sessionHash)
+ require.NotNil(t, c.Request)
+ require.NotNil(t, c.Request.Context())
+ require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
+}
+
+func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+
+ svc := &OpenAIGatewayService{}
+ seed := "openai_ws_ingress:9:100:200"
+
+ got := svc.GenerateSessionHashWithFallback(c, []byte(`{}`), seed)
+ want := fmt.Sprintf("%016x", xxhash.Sum64String(seed))
+ require.Equal(t, want, got)
+ require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
+
+ empty := svc.GenerateSessionHashWithFallback(c, []byte(`{}`), " ")
+ require.Equal(t, "", empty)
+}
+
func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if c.waitCounts != nil {
if count, ok := c.waitCounts[accountID]; ok {
diff --git a/backend/internal/service/openai_json_optimization_benchmark_test.go b/backend/internal/service/openai_json_optimization_benchmark_test.go
new file mode 100644
index 00000000..1737804b
--- /dev/null
+++ b/backend/internal/service/openai_json_optimization_benchmark_test.go
@@ -0,0 +1,357 @@
+package service
+
+import (
+ "encoding/json"
+ "strconv"
+ "strings"
+ "testing"
+
+ "github.com/tidwall/gjson"
+)
+
+var (
+ benchmarkToolContinuationBoolSink bool
+ benchmarkWSParseStringSink string
+ benchmarkWSParseMapSink map[string]any
+ benchmarkUsageSink OpenAIUsage
+)
+
+func BenchmarkToolContinuationValidationLegacy(b *testing.B) {
+ reqBody := benchmarkToolContinuationRequestBody()
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ benchmarkToolContinuationBoolSink = legacyValidateFunctionCallOutputContext(reqBody)
+ }
+}
+
+func BenchmarkToolContinuationValidationOptimized(b *testing.B) {
+ reqBody := benchmarkToolContinuationRequestBody()
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ benchmarkToolContinuationBoolSink = optimizedValidateFunctionCallOutputContext(reqBody)
+ }
+}
+
+func BenchmarkWSIngressPayloadParseLegacy(b *testing.B) {
+ raw := benchmarkWSIngressPayloadBytes()
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ eventType, model, promptCacheKey, previousResponseID, payload, err := legacyParseWSIngressPayload(raw)
+ if err == nil {
+ benchmarkWSParseStringSink = eventType + model + promptCacheKey + previousResponseID
+ benchmarkWSParseMapSink = payload
+ }
+ }
+}
+
+func BenchmarkWSIngressPayloadParseOptimized(b *testing.B) {
+ raw := benchmarkWSIngressPayloadBytes()
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ eventType, model, promptCacheKey, previousResponseID, payload, err := optimizedParseWSIngressPayload(raw)
+ if err == nil {
+ benchmarkWSParseStringSink = eventType + model + promptCacheKey + previousResponseID
+ benchmarkWSParseMapSink = payload
+ }
+ }
+}
+
+func BenchmarkOpenAIUsageExtractLegacy(b *testing.B) {
+ body := benchmarkOpenAIUsageJSONBytes()
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ usage, ok := legacyExtractOpenAIUsageFromJSONBytes(body)
+ if ok {
+ benchmarkUsageSink = usage
+ }
+ }
+}
+
+func BenchmarkOpenAIUsageExtractOptimized(b *testing.B) {
+ body := benchmarkOpenAIUsageJSONBytes()
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ usage, ok := extractOpenAIUsageFromJSONBytes(body)
+ if ok {
+ benchmarkUsageSink = usage
+ }
+ }
+}
+
+func benchmarkToolContinuationRequestBody() map[string]any {
+ input := make([]any, 0, 64)
+ for i := 0; i < 24; i++ {
+ input = append(input, map[string]any{
+ "type": "text",
+ "text": "benchmark text",
+ })
+ }
+ for i := 0; i < 10; i++ {
+ callID := "call_" + strconv.Itoa(i)
+ input = append(input, map[string]any{
+ "type": "tool_call",
+ "call_id": callID,
+ })
+ input = append(input, map[string]any{
+ "type": "function_call_output",
+ "call_id": callID,
+ })
+ input = append(input, map[string]any{
+ "type": "item_reference",
+ "id": callID,
+ })
+ }
+ return map[string]any{
+ "model": "gpt-5.3-codex",
+ "input": input,
+ }
+}
+
+func benchmarkWSIngressPayloadBytes() []byte {
+ return []byte(`{"type":"response.create","model":"gpt-5.3-codex","prompt_cache_key":"cache_bench","previous_response_id":"resp_prev_bench","input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}]}`)
+}
+
+func benchmarkOpenAIUsageJSONBytes() []byte {
+ return []byte(`{"id":"resp_bench","object":"response","model":"gpt-5.3-codex","usage":{"input_tokens":3210,"output_tokens":987,"input_tokens_details":{"cached_tokens":456}}}`)
+}
+
+func legacyValidateFunctionCallOutputContext(reqBody map[string]any) bool {
+ if !legacyHasFunctionCallOutput(reqBody) {
+ return true
+ }
+ previousResponseID, _ := reqBody["previous_response_id"].(string)
+ if strings.TrimSpace(previousResponseID) != "" {
+ return true
+ }
+ if legacyHasToolCallContext(reqBody) {
+ return true
+ }
+ if legacyHasFunctionCallOutputMissingCallID(reqBody) {
+ return false
+ }
+ callIDs := legacyFunctionCallOutputCallIDs(reqBody)
+ return legacyHasItemReferenceForCallIDs(reqBody, callIDs)
+}
+
+func optimizedValidateFunctionCallOutputContext(reqBody map[string]any) bool {
+ validation := ValidateFunctionCallOutputContext(reqBody)
+ if !validation.HasFunctionCallOutput {
+ return true
+ }
+ previousResponseID, _ := reqBody["previous_response_id"].(string)
+ if strings.TrimSpace(previousResponseID) != "" {
+ return true
+ }
+ if validation.HasToolCallContext {
+ return true
+ }
+ if validation.HasFunctionCallOutputMissingCallID {
+ return false
+ }
+ return validation.HasItemReferenceForAllCallIDs
+}
+
+func legacyHasFunctionCallOutput(reqBody map[string]any) bool {
+ if reqBody == nil {
+ return false
+ }
+ input, ok := reqBody["input"].([]any)
+ if !ok {
+ return false
+ }
+ for _, item := range input {
+ itemMap, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+ itemType, _ := itemMap["type"].(string)
+ if itemType == "function_call_output" {
+ return true
+ }
+ }
+ return false
+}
+
+func legacyHasToolCallContext(reqBody map[string]any) bool {
+ if reqBody == nil {
+ return false
+ }
+ input, ok := reqBody["input"].([]any)
+ if !ok {
+ return false
+ }
+ for _, item := range input {
+ itemMap, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+ itemType, _ := itemMap["type"].(string)
+ if itemType != "tool_call" && itemType != "function_call" {
+ continue
+ }
+ if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
+ return true
+ }
+ }
+ return false
+}
+
+func legacyFunctionCallOutputCallIDs(reqBody map[string]any) []string {
+ if reqBody == nil {
+ return nil
+ }
+ input, ok := reqBody["input"].([]any)
+ if !ok {
+ return nil
+ }
+ ids := make(map[string]struct{})
+ for _, item := range input {
+ itemMap, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+ itemType, _ := itemMap["type"].(string)
+ if itemType != "function_call_output" {
+ continue
+ }
+ if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
+ ids[callID] = struct{}{}
+ }
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ callIDs := make([]string, 0, len(ids))
+ for id := range ids {
+ callIDs = append(callIDs, id)
+ }
+ return callIDs
+}
+
+func legacyHasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
+ if reqBody == nil {
+ return false
+ }
+ input, ok := reqBody["input"].([]any)
+ if !ok {
+ return false
+ }
+ for _, item := range input {
+ itemMap, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+ itemType, _ := itemMap["type"].(string)
+ if itemType != "function_call_output" {
+ continue
+ }
+ callID, _ := itemMap["call_id"].(string)
+ if strings.TrimSpace(callID) == "" {
+ return true
+ }
+ }
+ return false
+}
+
+func legacyHasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
+ if reqBody == nil || len(callIDs) == 0 {
+ return false
+ }
+ input, ok := reqBody["input"].([]any)
+ if !ok {
+ return false
+ }
+ referenceIDs := make(map[string]struct{})
+ for _, item := range input {
+ itemMap, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+ itemType, _ := itemMap["type"].(string)
+ if itemType != "item_reference" {
+ continue
+ }
+ idValue, _ := itemMap["id"].(string)
+ idValue = strings.TrimSpace(idValue)
+ if idValue == "" {
+ continue
+ }
+ referenceIDs[idValue] = struct{}{}
+ }
+ if len(referenceIDs) == 0 {
+ return false
+ }
+ for _, callID := range callIDs {
+ if _, ok := referenceIDs[callID]; !ok {
+ return false
+ }
+ }
+ return true
+}
+
+func legacyParseWSIngressPayload(raw []byte) (eventType, model, promptCacheKey, previousResponseID string, payload map[string]any, err error) {
+ values := gjson.GetManyBytes(raw, "type", "model", "prompt_cache_key", "previous_response_id")
+ eventType = strings.TrimSpace(values[0].String())
+ if eventType == "" {
+ eventType = "response.create"
+ }
+ model = strings.TrimSpace(values[1].String())
+ promptCacheKey = strings.TrimSpace(values[2].String())
+ previousResponseID = strings.TrimSpace(values[3].String())
+ payload = make(map[string]any)
+ if err = json.Unmarshal(raw, &payload); err != nil {
+ return "", "", "", "", nil, err
+ }
+ if _, exists := payload["type"]; !exists {
+ payload["type"] = "response.create"
+ }
+ return eventType, model, promptCacheKey, previousResponseID, payload, nil
+}
+
+func optimizedParseWSIngressPayload(raw []byte) (eventType, model, promptCacheKey, previousResponseID string, payload map[string]any, err error) {
+ payload = make(map[string]any)
+ if err = json.Unmarshal(raw, &payload); err != nil {
+ return "", "", "", "", nil, err
+ }
+ eventType = openAIWSPayloadString(payload, "type")
+ if eventType == "" {
+ eventType = "response.create"
+ payload["type"] = eventType
+ }
+ model = openAIWSPayloadString(payload, "model")
+ promptCacheKey = openAIWSPayloadString(payload, "prompt_cache_key")
+ previousResponseID = openAIWSPayloadString(payload, "previous_response_id")
+ return eventType, model, promptCacheKey, previousResponseID, payload, nil
+}
+
+func legacyExtractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
+ var response struct {
+ Usage struct {
+ InputTokens int `json:"input_tokens"`
+ OutputTokens int `json:"output_tokens"`
+ InputTokenDetails struct {
+ CachedTokens int `json:"cached_tokens"`
+ } `json:"input_tokens_details"`
+ } `json:"usage"`
+ }
+ if err := json.Unmarshal(body, &response); err != nil {
+ return OpenAIUsage{}, false
+ }
+ return OpenAIUsage{
+ InputTokens: response.Usage.InputTokens,
+ OutputTokens: response.Usage.OutputTokens,
+ CacheReadInputTokens: response.Usage.InputTokenDetails.CachedTokens,
+ }, true
+}
diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go
index 7a996c26..0840d3b1 100644
--- a/backend/internal/service/openai_oauth_passthrough_test.go
+++ b/backend/internal/service/openai_oauth_passthrough_test.go
@@ -515,7 +515,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *te
require.NoError(t, err)
require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool())
require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool())
- require.Equal(t, "codex_cli_rs/0.98.0", upstream.lastReq.Header.Get("User-Agent"))
+ require.Equal(t, "codex_cli_rs/0.104.0", upstream.lastReq.Header.Get("User-Agent"))
}
func TestOpenAIGatewayService_CodexCLIOnly_RejectsNonCodexClient(t *testing.T) {
diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go
index 087ad4ec..07cb5472 100644
--- a/backend/internal/service/openai_oauth_service.go
+++ b/backend/internal/service/openai_oauth_service.go
@@ -5,8 +5,12 @@ import (
"crypto/subtle"
"encoding/json"
"io"
+ "log/slog"
"net/http"
"net/url"
+ "regexp"
+ "sort"
+ "strconv"
"strings"
"time"
@@ -16,6 +20,13 @@ import (
var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
+var soraSessionCookiePattern = regexp.MustCompile(`(?i)(?:^|[\n\r;])\s*(?:(?:set-cookie|cookie)\s*:\s*)?__Secure-(?:next-auth|authjs)\.session-token(?:\.(\d+))?=([^;\r\n]+)`)
+
+type soraSessionChunk struct {
+ index int
+ value string
+}
+
// OpenAIOAuthService handles OpenAI OAuth authentication flows
type OpenAIOAuthService struct {
sessionStore *openai.SessionStore
@@ -39,7 +50,7 @@ type OpenAIAuthURLResult struct {
}
// GenerateAuthURL generates an OpenAI OAuth authorization URL
-func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI string) (*OpenAIAuthURLResult, error) {
+func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, platform string) (*OpenAIAuthURLResult, error) {
// Generate PKCE values
state, err := openai.GenerateState()
if err != nil {
@@ -75,11 +86,14 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
}
+ normalizedPlatform := normalizeOpenAIOAuthPlatform(platform)
+ clientID, _ := openai.OAuthClientConfigByPlatform(normalizedPlatform)
// Store session
session := &openai.OAuthSession{
State: state,
CodeVerifier: codeVerifier,
+ ClientID: clientID,
RedirectURI: redirectURI,
ProxyURL: proxyURL,
CreatedAt: time.Now(),
@@ -87,7 +101,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
s.sessionStore.Set(sessionID, session)
// Build authorization URL
- authURL := openai.BuildAuthorizationURL(state, codeChallenge, redirectURI)
+ authURL := openai.BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, normalizedPlatform)
return &OpenAIAuthURLResult{
AuthURL: authURL,
@@ -111,6 +125,7 @@ type OpenAITokenInfo struct {
IDToken string `json:"id_token,omitempty"`
ExpiresIn int64 `json:"expires_in"`
ExpiresAt int64 `json:"expires_at"`
+ ClientID string `json:"client_id,omitempty"`
Email string `json:"email,omitempty"`
ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"`
ChatGPTUserID string `json:"chatgpt_user_id,omitempty"`
@@ -148,9 +163,13 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
if input.RedirectURI != "" {
redirectURI = input.RedirectURI
}
+ clientID := strings.TrimSpace(session.ClientID)
+ if clientID == "" {
+ clientID = openai.ClientID
+ }
// Exchange code for token
- tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
+ tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL, clientID)
if err != nil {
return nil, err
}
@@ -158,8 +177,10 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
// Parse ID token to get user info
var userInfo *openai.UserInfo
if tokenResp.IDToken != "" {
- claims, err := openai.ParseIDToken(tokenResp.IDToken)
- if err == nil {
+ claims, parseErr := openai.ParseIDToken(tokenResp.IDToken)
+ if parseErr != nil {
+ slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr)
+ } else {
userInfo = claims.GetUserInfo()
}
}
@@ -173,6 +194,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
IDToken: tokenResp.IDToken,
ExpiresIn: int64(tokenResp.ExpiresIn),
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
+ ClientID: clientID,
}
if userInfo != nil {
@@ -200,8 +222,10 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
// Parse ID token to get user info
var userInfo *openai.UserInfo
if tokenResp.IDToken != "" {
- claims, err := openai.ParseIDToken(tokenResp.IDToken)
- if err == nil {
+ claims, parseErr := openai.ParseIDToken(tokenResp.IDToken)
+ if parseErr != nil {
+ slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr)
+ } else {
userInfo = claims.GetUserInfo()
}
}
@@ -213,6 +237,9 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
ExpiresIn: int64(tokenResp.ExpiresIn),
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
}
+ if trimmed := strings.TrimSpace(clientID); trimmed != "" {
+ tokenInfo.ClientID = trimmed
+ }
if userInfo != nil {
tokenInfo.Email = userInfo.Email
@@ -226,6 +253,7 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
// ExchangeSoraSessionToken exchanges Sora session_token to access_token.
func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) {
+ sessionToken = normalizeSoraSessionTokenInput(sessionToken)
if strings.TrimSpace(sessionToken) == "" {
return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required")
}
@@ -287,10 +315,141 @@ func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessi
AccessToken: strings.TrimSpace(sessionResp.AccessToken),
ExpiresIn: expiresIn,
ExpiresAt: expiresAt,
+ ClientID: openai.SoraClientID,
Email: strings.TrimSpace(sessionResp.User.Email),
}, nil
}
+func normalizeSoraSessionTokenInput(raw string) string {
+ trimmed := strings.TrimSpace(raw)
+ if trimmed == "" {
+ return ""
+ }
+
+ matches := soraSessionCookiePattern.FindAllStringSubmatch(trimmed, -1)
+ if len(matches) == 0 {
+ return sanitizeSessionToken(trimmed)
+ }
+
+ chunkMatches := make([]soraSessionChunk, 0, len(matches))
+ singleValues := make([]string, 0, len(matches))
+
+ for _, match := range matches {
+ if len(match) < 3 {
+ continue
+ }
+
+ value := sanitizeSessionToken(match[2])
+ if value == "" {
+ continue
+ }
+
+ if strings.TrimSpace(match[1]) == "" {
+ singleValues = append(singleValues, value)
+ continue
+ }
+
+ idx, err := strconv.Atoi(strings.TrimSpace(match[1]))
+ if err != nil || idx < 0 {
+ continue
+ }
+ chunkMatches = append(chunkMatches, soraSessionChunk{
+ index: idx,
+ value: value,
+ })
+ }
+
+ if merged := mergeLatestSoraSessionChunks(chunkMatches); merged != "" {
+ return merged
+ }
+
+ if len(singleValues) > 0 {
+ return singleValues[len(singleValues)-1]
+ }
+
+ return ""
+}
+
+func mergeSoraSessionChunkSegment(chunks []soraSessionChunk, requiredMaxIndex int, requireComplete bool) string {
+ if len(chunks) == 0 {
+ return ""
+ }
+
+ byIndex := make(map[int]string, len(chunks))
+ for _, chunk := range chunks {
+ byIndex[chunk.index] = chunk.value
+ }
+
+ if _, ok := byIndex[0]; !ok {
+ return ""
+ }
+ if requireComplete {
+ for idx := 0; idx <= requiredMaxIndex; idx++ {
+ if _, ok := byIndex[idx]; !ok {
+ return ""
+ }
+ }
+ }
+
+ orderedIndexes := make([]int, 0, len(byIndex))
+ for idx := range byIndex {
+ orderedIndexes = append(orderedIndexes, idx)
+ }
+ sort.Ints(orderedIndexes)
+
+ var builder strings.Builder
+ for _, idx := range orderedIndexes {
+ if _, err := builder.WriteString(byIndex[idx]); err != nil {
+ return ""
+ }
+ }
+ return sanitizeSessionToken(builder.String())
+}
+
+func mergeLatestSoraSessionChunks(chunks []soraSessionChunk) string {
+ if len(chunks) == 0 {
+ return ""
+ }
+
+ requiredMaxIndex := 0
+ for _, chunk := range chunks {
+ if chunk.index > requiredMaxIndex {
+ requiredMaxIndex = chunk.index
+ }
+ }
+
+ groupStarts := make([]int, 0, len(chunks))
+ for idx, chunk := range chunks {
+ if chunk.index == 0 {
+ groupStarts = append(groupStarts, idx)
+ }
+ }
+
+ if len(groupStarts) == 0 {
+ return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false)
+ }
+
+ for i := len(groupStarts) - 1; i >= 0; i-- {
+ start := groupStarts[i]
+ end := len(chunks)
+ if i+1 < len(groupStarts) {
+ end = groupStarts[i+1]
+ }
+ if merged := mergeSoraSessionChunkSegment(chunks[start:end], requiredMaxIndex, true); merged != "" {
+ return merged
+ }
+ }
+
+ return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false)
+}
+
+func sanitizeSessionToken(raw string) string {
+ token := strings.TrimSpace(raw)
+ token = strings.Trim(token, "\"'`")
+ token = strings.TrimSuffix(token, ";")
+ return strings.TrimSpace(token)
+}
+
// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
if account.Platform != PlatformOpenAI && account.Platform != PlatformSora {
@@ -322,9 +481,12 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
expiresAt := time.Unix(tokenInfo.ExpiresAt, 0).Format(time.RFC3339)
creds := map[string]any{
- "access_token": tokenInfo.AccessToken,
- "refresh_token": tokenInfo.RefreshToken,
- "expires_at": expiresAt,
+ "access_token": tokenInfo.AccessToken,
+ "expires_at": expiresAt,
+ }
+ // 仅在刷新响应返回了新的 refresh_token 时才更新,防止用空值覆盖已有令牌
+ if strings.TrimSpace(tokenInfo.RefreshToken) != "" {
+ creds["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.IDToken != "" {
@@ -342,6 +504,9 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
if tokenInfo.OrganizationID != "" {
creds["organization_id"] = tokenInfo.OrganizationID
}
+ if strings.TrimSpace(tokenInfo.ClientID) != "" {
+ creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID)
+ }
return creds
}
@@ -377,3 +542,12 @@ func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client {
Transport: transport,
}
}
+
+func normalizeOpenAIOAuthPlatform(platform string) string {
+ switch strings.ToLower(strings.TrimSpace(platform)) {
+ case PlatformSora:
+ return openai.OAuthPlatformSora
+ default:
+ return openai.OAuthPlatformOpenAI
+ }
+}
diff --git a/backend/internal/service/openai_oauth_service_auth_url_test.go b/backend/internal/service/openai_oauth_service_auth_url_test.go
new file mode 100644
index 00000000..5f26903d
--- /dev/null
+++ b/backend/internal/service/openai_oauth_service_auth_url_test.go
@@ -0,0 +1,67 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "net/url"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
+ "github.com/stretchr/testify/require"
+)
+
+type openaiOAuthClientAuthURLStub struct{}
+
+func (s *openaiOAuthClientAuthURLStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (s *openaiOAuthClientAuthURLStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (s *openaiOAuthClientAuthURLStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
+ return nil, errors.New("not implemented")
+}
+
+func TestOpenAIOAuthService_GenerateAuthURL_OpenAIKeepsCodexFlow(t *testing.T) {
+ svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{})
+ defer svc.Stop()
+
+ result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformOpenAI)
+ require.NoError(t, err)
+ require.NotEmpty(t, result.AuthURL)
+ require.NotEmpty(t, result.SessionID)
+
+ parsed, err := url.Parse(result.AuthURL)
+ require.NoError(t, err)
+ q := parsed.Query()
+ require.Equal(t, openai.ClientID, q.Get("client_id"))
+ require.Equal(t, "true", q.Get("codex_cli_simplified_flow"))
+
+ session, ok := svc.sessionStore.Get(result.SessionID)
+ require.True(t, ok)
+ require.Equal(t, openai.ClientID, session.ClientID)
+}
+
+// TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient 验证 Sora 平台复用 Codex CLI 的
+// client_id(支持 localhost redirect_uri),但不启用 codex_cli_simplified_flow。
+func TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient(t *testing.T) {
+ svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{})
+ defer svc.Stop()
+
+ result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformSora)
+ require.NoError(t, err)
+ require.NotEmpty(t, result.AuthURL)
+ require.NotEmpty(t, result.SessionID)
+
+ parsed, err := url.Parse(result.AuthURL)
+ require.NoError(t, err)
+ q := parsed.Query()
+ require.Equal(t, openai.ClientID, q.Get("client_id"))
+ require.Empty(t, q.Get("codex_cli_simplified_flow"))
+
+ session, ok := svc.sessionStore.Get(result.SessionID)
+ require.True(t, ok)
+ require.Equal(t, openai.ClientID, session.ClientID)
+}
diff --git a/backend/internal/service/openai_oauth_service_sora_session_test.go b/backend/internal/service/openai_oauth_service_sora_session_test.go
index fb76f6c1..08da8557 100644
--- a/backend/internal/service/openai_oauth_service_sora_session_test.go
+++ b/backend/internal/service/openai_oauth_service_sora_session_test.go
@@ -5,6 +5,7 @@ import (
"errors"
"net/http"
"net/http/httptest"
+ "strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
@@ -13,7 +14,7 @@ import (
type openaiOAuthClientNoopStub struct{}
-func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
+func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
return nil, errors.New("not implemented")
}
@@ -67,3 +68,106 @@ func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testi
require.Error(t, err)
require.Contains(t, err.Error(), "missing access token")
}
+
+func TestOpenAIOAuthService_ExchangeSoraSessionToken_AcceptsSetCookieLine(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodGet, r.Method)
+ require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-cookie-value")
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
+ }))
+ defer server.Close()
+
+ origin := openAISoraSessionAuthURL
+ openAISoraSessionAuthURL = server.URL
+ defer func() { openAISoraSessionAuthURL = origin }()
+
+ svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
+ defer svc.Stop()
+
+ raw := "__Secure-next-auth.session-token.0=st-cookie-value; Domain=.chatgpt.com; Path=/; HttpOnly; Secure; SameSite=Lax"
+ info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
+ require.NoError(t, err)
+ require.Equal(t, "at-token", info.AccessToken)
+}
+
+func TestOpenAIOAuthService_ExchangeSoraSessionToken_MergesChunkedSetCookieLines(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodGet, r.Method)
+ require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=chunk-0chunk-1")
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
+ }))
+ defer server.Close()
+
+ origin := openAISoraSessionAuthURL
+ openAISoraSessionAuthURL = server.URL
+ defer func() { openAISoraSessionAuthURL = origin }()
+
+ svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
+ defer svc.Stop()
+
+ raw := strings.Join([]string{
+ "Set-Cookie: __Secure-next-auth.session-token.1=chunk-1; Path=/; HttpOnly",
+ "Set-Cookie: __Secure-next-auth.session-token.0=chunk-0; Path=/; HttpOnly",
+ }, "\n")
+ info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
+ require.NoError(t, err)
+ require.Equal(t, "at-token", info.AccessToken)
+}
+
+func TestOpenAIOAuthService_ExchangeSoraSessionToken_PrefersLatestDuplicateChunks(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodGet, r.Method)
+ require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=new-0new-1")
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
+ }))
+ defer server.Close()
+
+ origin := openAISoraSessionAuthURL
+ openAISoraSessionAuthURL = server.URL
+ defer func() { openAISoraSessionAuthURL = origin }()
+
+ svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
+ defer svc.Stop()
+
+ raw := strings.Join([]string{
+ "Set-Cookie: __Secure-next-auth.session-token.0=old-0; Path=/; HttpOnly",
+ "Set-Cookie: __Secure-next-auth.session-token.1=old-1; Path=/; HttpOnly",
+ "Set-Cookie: __Secure-next-auth.session-token.0=new-0; Path=/; HttpOnly",
+ "Set-Cookie: __Secure-next-auth.session-token.1=new-1; Path=/; HttpOnly",
+ }, "\n")
+ info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
+ require.NoError(t, err)
+ require.Equal(t, "at-token", info.AccessToken)
+}
+
+func TestOpenAIOAuthService_ExchangeSoraSessionToken_UsesLatestCompleteChunkGroup(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodGet, r.Method)
+ require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=ok-0ok-1")
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
+ }))
+ defer server.Close()
+
+ origin := openAISoraSessionAuthURL
+ openAISoraSessionAuthURL = server.URL
+ defer func() { openAISoraSessionAuthURL = origin }()
+
+ svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
+ defer svc.Stop()
+
+ raw := strings.Join([]string{
+ "set-cookie",
+ "__Secure-next-auth.session-token.0=ok-0; Domain=.chatgpt.com; Path=/",
+ "set-cookie",
+ "__Secure-next-auth.session-token.1=ok-1; Domain=.chatgpt.com; Path=/",
+ "set-cookie",
+ "__Secure-next-auth.session-token.0=partial-0; Domain=.chatgpt.com; Path=/",
+ }, "\n")
+ info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil)
+ require.NoError(t, err)
+ require.Equal(t, "at-token", info.AccessToken)
+}
diff --git a/backend/internal/service/openai_oauth_service_state_test.go b/backend/internal/service/openai_oauth_service_state_test.go
index 0a2a195f..29252328 100644
--- a/backend/internal/service/openai_oauth_service_state_test.go
+++ b/backend/internal/service/openai_oauth_service_state_test.go
@@ -13,10 +13,12 @@ import (
type openaiOAuthClientStateStub struct {
exchangeCalled int32
+ lastClientID string
}
-func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
+func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
atomic.AddInt32(&s.exchangeCalled, 1)
+ s.lastClientID = clientID
return &openai.TokenResponse{
AccessToken: "at",
RefreshToken: "rt",
@@ -95,6 +97,8 @@ func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, info)
require.Equal(t, "at", info.AccessToken)
+ require.Equal(t, openai.ClientID, info.ClientID)
+ require.Equal(t, openai.ClientID, client.lastClientID)
require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled))
_, ok := svc.sessionStore.Get("sid")
diff --git a/backend/internal/service/openai_previous_response_id.go b/backend/internal/service/openai_previous_response_id.go
new file mode 100644
index 00000000..95865086
--- /dev/null
+++ b/backend/internal/service/openai_previous_response_id.go
@@ -0,0 +1,37 @@
+package service
+
+import (
+ "regexp"
+ "strings"
+)
+
+const (
+ OpenAIPreviousResponseIDKindEmpty = "empty"
+ OpenAIPreviousResponseIDKindResponseID = "response_id"
+ OpenAIPreviousResponseIDKindMessageID = "message_id"
+ OpenAIPreviousResponseIDKindUnknown = "unknown"
+)
+
+var (
+ openAIResponseIDPattern = regexp.MustCompile(`^resp_[A-Za-z0-9_-]{1,256}$`)
+ openAIMessageIDPattern = regexp.MustCompile(`^(msg|message|item|chatcmpl)_[A-Za-z0-9_-]{1,256}$`)
+)
+
+// ClassifyOpenAIPreviousResponseIDKind classifies previous_response_id to improve diagnostics.
+func ClassifyOpenAIPreviousResponseIDKind(id string) string {
+ trimmed := strings.TrimSpace(id)
+ if trimmed == "" {
+ return OpenAIPreviousResponseIDKindEmpty
+ }
+ if openAIResponseIDPattern.MatchString(trimmed) {
+ return OpenAIPreviousResponseIDKindResponseID
+ }
+ if openAIMessageIDPattern.MatchString(strings.ToLower(trimmed)) {
+ return OpenAIPreviousResponseIDKindMessageID
+ }
+ return OpenAIPreviousResponseIDKindUnknown
+}
+
+func IsOpenAIPreviousResponseIDLikelyMessageID(id string) bool {
+ return ClassifyOpenAIPreviousResponseIDKind(id) == OpenAIPreviousResponseIDKindMessageID
+}
diff --git a/backend/internal/service/openai_previous_response_id_test.go b/backend/internal/service/openai_previous_response_id_test.go
new file mode 100644
index 00000000..7867b864
--- /dev/null
+++ b/backend/internal/service/openai_previous_response_id_test.go
@@ -0,0 +1,34 @@
+package service
+
+import "testing"
+
+func TestClassifyOpenAIPreviousResponseIDKind(t *testing.T) {
+ tests := []struct {
+ name string
+ id string
+ want string
+ }{
+ {name: "empty", id: " ", want: OpenAIPreviousResponseIDKindEmpty},
+ {name: "response_id", id: "resp_0906a621bc423a8d0169a108637ef88197b74b0e2f37ba358f", want: OpenAIPreviousResponseIDKindResponseID},
+ {name: "message_id", id: "msg_123456", want: OpenAIPreviousResponseIDKindMessageID},
+ {name: "item_id", id: "item_abcdef", want: OpenAIPreviousResponseIDKindMessageID},
+ {name: "unknown", id: "foo_123456", want: OpenAIPreviousResponseIDKindUnknown},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ if got := ClassifyOpenAIPreviousResponseIDKind(tc.id); got != tc.want {
+ t.Fatalf("ClassifyOpenAIPreviousResponseIDKind(%q)=%q want=%q", tc.id, got, tc.want)
+ }
+ })
+ }
+}
+
+func TestIsOpenAIPreviousResponseIDLikelyMessageID(t *testing.T) {
+ if !IsOpenAIPreviousResponseIDLikelyMessageID("msg_123") {
+ t.Fatal("expected msg_123 to be identified as message id")
+ }
+ if IsOpenAIPreviousResponseIDLikelyMessageID("resp_123") {
+ t.Fatal("expected resp_123 not to be identified as message id")
+ }
+}
diff --git a/backend/internal/service/openai_sticky_compat.go b/backend/internal/service/openai_sticky_compat.go
new file mode 100644
index 00000000..e897debc
--- /dev/null
+++ b/backend/internal/service/openai_sticky_compat.go
@@ -0,0 +1,214 @@
+package service
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/hex"
+ "fmt"
+ "strings"
+ "sync/atomic"
+ "time"
+
+ "github.com/cespare/xxhash/v2"
+ "github.com/gin-gonic/gin"
+)
+
+type openAILegacySessionHashContextKey struct{}
+
+var openAILegacySessionHashKey = openAILegacySessionHashContextKey{}
+
+var (
+ openAIStickyLegacyReadFallbackTotal atomic.Int64
+ openAIStickyLegacyReadFallbackHit atomic.Int64
+ openAIStickyLegacyDualWriteTotal atomic.Int64
+)
+
+func openAIStickyCompatStats() (legacyReadFallbackTotal, legacyReadFallbackHit, legacyDualWriteTotal int64) {
+ return openAIStickyLegacyReadFallbackTotal.Load(),
+ openAIStickyLegacyReadFallbackHit.Load(),
+ openAIStickyLegacyDualWriteTotal.Load()
+}
+
+func deriveOpenAISessionHashes(sessionID string) (currentHash string, legacyHash string) {
+ normalized := strings.TrimSpace(sessionID)
+ if normalized == "" {
+ return "", ""
+ }
+
+ currentHash = fmt.Sprintf("%016x", xxhash.Sum64String(normalized))
+ sum := sha256.Sum256([]byte(normalized))
+ legacyHash = hex.EncodeToString(sum[:])
+ return currentHash, legacyHash
+}
+
+func withOpenAILegacySessionHash(ctx context.Context, legacyHash string) context.Context {
+ if ctx == nil {
+ return nil
+ }
+ trimmed := strings.TrimSpace(legacyHash)
+ if trimmed == "" {
+ return ctx
+ }
+ return context.WithValue(ctx, openAILegacySessionHashKey, trimmed)
+}
+
+func openAILegacySessionHashFromContext(ctx context.Context) string {
+ if ctx == nil {
+ return ""
+ }
+ value, _ := ctx.Value(openAILegacySessionHashKey).(string)
+ return strings.TrimSpace(value)
+}
+
+func attachOpenAILegacySessionHashToGin(c *gin.Context, legacyHash string) {
+ if c == nil || c.Request == nil {
+ return
+ }
+ c.Request = c.Request.WithContext(withOpenAILegacySessionHash(c.Request.Context(), legacyHash))
+}
+
+func (s *OpenAIGatewayService) openAISessionHashReadOldFallbackEnabled() bool {
+ if s == nil || s.cfg == nil {
+ return true
+ }
+ return s.cfg.Gateway.OpenAIWS.SessionHashReadOldFallback
+}
+
+func (s *OpenAIGatewayService) openAISessionHashDualWriteOldEnabled() bool {
+ if s == nil || s.cfg == nil {
+ return true
+ }
+ return s.cfg.Gateway.OpenAIWS.SessionHashDualWriteOld
+}
+
+func (s *OpenAIGatewayService) openAISessionCacheKey(sessionHash string) string {
+ normalized := strings.TrimSpace(sessionHash)
+ if normalized == "" {
+ return ""
+ }
+ return "openai:" + normalized
+}
+
+func (s *OpenAIGatewayService) openAILegacySessionCacheKey(ctx context.Context, sessionHash string) string {
+ legacyHash := openAILegacySessionHashFromContext(ctx)
+ if legacyHash == "" {
+ return ""
+ }
+ legacyKey := "openai:" + legacyHash
+ if legacyKey == s.openAISessionCacheKey(sessionHash) {
+ return ""
+ }
+ return legacyKey
+}
+
+func (s *OpenAIGatewayService) openAIStickyLegacyTTL(ttl time.Duration) time.Duration {
+ legacyTTL := ttl
+ if legacyTTL <= 0 {
+ legacyTTL = openaiStickySessionTTL
+ }
+ if legacyTTL > 10*time.Minute {
+ return 10 * time.Minute
+ }
+ return legacyTTL
+}
+
+func (s *OpenAIGatewayService) getStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) {
+ if s == nil || s.cache == nil {
+ return 0, nil
+ }
+
+ primaryKey := s.openAISessionCacheKey(sessionHash)
+ if primaryKey == "" {
+ return 0, nil
+ }
+
+ accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), primaryKey)
+ if err == nil && accountID > 0 {
+ return accountID, nil
+ }
+ if !s.openAISessionHashReadOldFallbackEnabled() {
+ return accountID, err
+ }
+
+ legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
+ if legacyKey == "" {
+ return accountID, err
+ }
+
+ openAIStickyLegacyReadFallbackTotal.Add(1)
+ legacyAccountID, legacyErr := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), legacyKey)
+ if legacyErr == nil && legacyAccountID > 0 {
+ openAIStickyLegacyReadFallbackHit.Add(1)
+ return legacyAccountID, nil
+ }
+ return accountID, err
+}
+
+func (s *OpenAIGatewayService) setStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string, accountID int64, ttl time.Duration) error {
+ if s == nil || s.cache == nil || accountID <= 0 {
+ return nil
+ }
+ primaryKey := s.openAISessionCacheKey(sessionHash)
+ if primaryKey == "" {
+ return nil
+ }
+
+ if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), primaryKey, accountID, ttl); err != nil {
+ return err
+ }
+
+ if !s.openAISessionHashDualWriteOldEnabled() {
+ return nil
+ }
+ legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
+ if legacyKey == "" {
+ return nil
+ }
+ if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), legacyKey, accountID, s.openAIStickyLegacyTTL(ttl)); err != nil {
+ return err
+ }
+ openAIStickyLegacyDualWriteTotal.Add(1)
+ return nil
+}
+
+func (s *OpenAIGatewayService) refreshStickySessionTTL(ctx context.Context, groupID *int64, sessionHash string, ttl time.Duration) error {
+ if s == nil || s.cache == nil {
+ return nil
+ }
+ primaryKey := s.openAISessionCacheKey(sessionHash)
+ if primaryKey == "" {
+ return nil
+ }
+
+ err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), primaryKey, ttl)
+ if !s.openAISessionHashReadOldFallbackEnabled() && !s.openAISessionHashDualWriteOldEnabled() {
+ return err
+ }
+
+ legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
+ if legacyKey != "" {
+ _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), legacyKey, s.openAIStickyLegacyTTL(ttl))
+ }
+ return err
+}
+
+func (s *OpenAIGatewayService) deleteStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string) error {
+ if s == nil || s.cache == nil {
+ return nil
+ }
+ primaryKey := s.openAISessionCacheKey(sessionHash)
+ if primaryKey == "" {
+ return nil
+ }
+
+ err := s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), primaryKey)
+ if !s.openAISessionHashReadOldFallbackEnabled() && !s.openAISessionHashDualWriteOldEnabled() {
+ return err
+ }
+
+ legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash)
+ if legacyKey != "" {
+ _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), legacyKey)
+ }
+ return err
+}
diff --git a/backend/internal/service/openai_sticky_compat_test.go b/backend/internal/service/openai_sticky_compat_test.go
new file mode 100644
index 00000000..9f57c358
--- /dev/null
+++ b/backend/internal/service/openai_sticky_compat_test.go
@@ -0,0 +1,96 @@
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
+ "github.com/stretchr/testify/require"
+)
+
+func TestGetStickySessionAccountID_FallbackToLegacyKey(t *testing.T) {
+ beforeFallbackTotal, beforeFallbackHit, _ := openAIStickyCompatStats()
+
+ cache := &stubGatewayCache{
+ sessionBindings: map[string]int64{
+ "openai:legacy-hash": 42,
+ },
+ }
+ svc := &OpenAIGatewayService{
+ cache: cache,
+ cfg: &config.Config{
+ Gateway: config.GatewayConfig{
+ OpenAIWS: config.GatewayOpenAIWSConfig{
+ SessionHashReadOldFallback: true,
+ },
+ },
+ },
+ }
+
+ ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash")
+ accountID, err := svc.getStickySessionAccountID(ctx, nil, "new-hash")
+ require.NoError(t, err)
+ require.Equal(t, int64(42), accountID)
+
+ afterFallbackTotal, afterFallbackHit, _ := openAIStickyCompatStats()
+ require.Equal(t, beforeFallbackTotal+1, afterFallbackTotal)
+ require.Equal(t, beforeFallbackHit+1, afterFallbackHit)
+}
+
+func TestSetStickySessionAccountID_DualWriteOldEnabled(t *testing.T) {
+ _, _, beforeDualWriteTotal := openAIStickyCompatStats()
+
+ cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
+ svc := &OpenAIGatewayService{
+ cache: cache,
+ cfg: &config.Config{
+ Gateway: config.GatewayConfig{
+ OpenAIWS: config.GatewayOpenAIWSConfig{
+ SessionHashDualWriteOld: true,
+ },
+ },
+ },
+ }
+
+ ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash")
+ err := svc.setStickySessionAccountID(ctx, nil, "new-hash", 9, openaiStickySessionTTL)
+ require.NoError(t, err)
+ require.Equal(t, int64(9), cache.sessionBindings["openai:new-hash"])
+ require.Equal(t, int64(9), cache.sessionBindings["openai:legacy-hash"])
+
+ _, _, afterDualWriteTotal := openAIStickyCompatStats()
+ require.Equal(t, beforeDualWriteTotal+1, afterDualWriteTotal)
+}
+
+func TestSetStickySessionAccountID_DualWriteOldDisabled(t *testing.T) {
+ cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
+ svc := &OpenAIGatewayService{
+ cache: cache,
+ cfg: &config.Config{
+ Gateway: config.GatewayConfig{
+ OpenAIWS: config.GatewayOpenAIWSConfig{
+ SessionHashDualWriteOld: false,
+ },
+ },
+ },
+ }
+
+ ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash")
+ err := svc.setStickySessionAccountID(ctx, nil, "new-hash", 9, openaiStickySessionTTL)
+ require.NoError(t, err)
+ require.Equal(t, int64(9), cache.sessionBindings["openai:new-hash"])
+ _, exists := cache.sessionBindings["openai:legacy-hash"]
+ require.False(t, exists)
+}
+
+func TestSnapshotOpenAICompatibilityFallbackMetrics(t *testing.T) {
+ before := SnapshotOpenAICompatibilityFallbackMetrics()
+
+ ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, true)
+ _, _ = ThinkingEnabledFromContext(ctx)
+
+ after := SnapshotOpenAICompatibilityFallbackMetrics()
+ require.GreaterOrEqual(t, after.MetadataLegacyFallbackTotal, before.MetadataLegacyFallbackTotal+1)
+ require.GreaterOrEqual(t, after.MetadataLegacyFallbackThinkingEnabledTotal, before.MetadataLegacyFallbackThinkingEnabledTotal+1)
+}
diff --git a/backend/internal/service/openai_tool_continuation.go b/backend/internal/service/openai_tool_continuation.go
index e59082b2..dea3c172 100644
--- a/backend/internal/service/openai_tool_continuation.go
+++ b/backend/internal/service/openai_tool_continuation.go
@@ -2,6 +2,24 @@ package service
import "strings"
+// ToolContinuationSignals 聚合工具续链相关信号,避免重复遍历 input。
+type ToolContinuationSignals struct {
+ HasFunctionCallOutput bool
+ HasFunctionCallOutputMissingCallID bool
+ HasToolCallContext bool
+ HasItemReference bool
+ HasItemReferenceForAllCallIDs bool
+ FunctionCallOutputCallIDs []string
+}
+
+// FunctionCallOutputValidation 汇总 function_call_output 关联性校验结果。
+type FunctionCallOutputValidation struct {
+ HasFunctionCallOutput bool
+ HasToolCallContext bool
+ HasFunctionCallOutputMissingCallID bool
+ HasItemReferenceForAllCallIDs bool
+}
+
// NeedsToolContinuation 判定请求是否需要工具调用续链处理。
// 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、
// 或显式声明 tools/tool_choice。
@@ -18,107 +36,191 @@ func NeedsToolContinuation(reqBody map[string]any) bool {
if hasToolChoiceSignal(reqBody) {
return true
}
- if inputHasType(reqBody, "function_call_output") {
- return true
+ input, ok := reqBody["input"].([]any)
+ if !ok {
+ return false
}
- if inputHasType(reqBody, "item_reference") {
- return true
+ for _, item := range input {
+ itemMap, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+ itemType, _ := itemMap["type"].(string)
+ if itemType == "function_call_output" || itemType == "item_reference" {
+ return true
+ }
}
return false
}
+// AnalyzeToolContinuationSignals 单次遍历 input,提取 function_call_output/tool_call/item_reference 相关信号。
+func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSignals {
+ signals := ToolContinuationSignals{}
+ if reqBody == nil {
+ return signals
+ }
+ input, ok := reqBody["input"].([]any)
+ if !ok {
+ return signals
+ }
+
+ var callIDs map[string]struct{}
+ var referenceIDs map[string]struct{}
+
+ for _, item := range input {
+ itemMap, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+ itemType, _ := itemMap["type"].(string)
+ switch itemType {
+ case "tool_call", "function_call":
+ callID, _ := itemMap["call_id"].(string)
+ if strings.TrimSpace(callID) != "" {
+ signals.HasToolCallContext = true
+ }
+ case "function_call_output":
+ signals.HasFunctionCallOutput = true
+ callID, _ := itemMap["call_id"].(string)
+ callID = strings.TrimSpace(callID)
+ if callID == "" {
+ signals.HasFunctionCallOutputMissingCallID = true
+ continue
+ }
+ if callIDs == nil {
+ callIDs = make(map[string]struct{})
+ }
+ callIDs[callID] = struct{}{}
+ case "item_reference":
+ signals.HasItemReference = true
+ idValue, _ := itemMap["id"].(string)
+ idValue = strings.TrimSpace(idValue)
+ if idValue == "" {
+ continue
+ }
+ if referenceIDs == nil {
+ referenceIDs = make(map[string]struct{})
+ }
+ referenceIDs[idValue] = struct{}{}
+ }
+ }
+
+ if len(callIDs) == 0 {
+ return signals
+ }
+ signals.FunctionCallOutputCallIDs = make([]string, 0, len(callIDs))
+ allReferenced := len(referenceIDs) > 0
+ for callID := range callIDs {
+ signals.FunctionCallOutputCallIDs = append(signals.FunctionCallOutputCallIDs, callID)
+ if allReferenced {
+ if _, ok := referenceIDs[callID]; !ok {
+ allReferenced = false
+ }
+ }
+ }
+ signals.HasItemReferenceForAllCallIDs = allReferenced
+ return signals
+}
+
+// ValidateFunctionCallOutputContext 为 handler 提供低开销校验结果:
+// 1) 无 function_call_output 直接返回
+// 2) 若已存在 tool_call/function_call 上下文则提前返回
+// 3) 仅在无工具上下文时才构建 call_id / item_reference 集合
+func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutputValidation {
+ result := FunctionCallOutputValidation{}
+ if reqBody == nil {
+ return result
+ }
+ input, ok := reqBody["input"].([]any)
+ if !ok {
+ return result
+ }
+
+ for _, item := range input {
+ itemMap, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+ itemType, _ := itemMap["type"].(string)
+ switch itemType {
+ case "function_call_output":
+ result.HasFunctionCallOutput = true
+ case "tool_call", "function_call":
+ callID, _ := itemMap["call_id"].(string)
+ if strings.TrimSpace(callID) != "" {
+ result.HasToolCallContext = true
+ }
+ }
+ if result.HasFunctionCallOutput && result.HasToolCallContext {
+ return result
+ }
+ }
+
+ if !result.HasFunctionCallOutput || result.HasToolCallContext {
+ return result
+ }
+
+ callIDs := make(map[string]struct{})
+ referenceIDs := make(map[string]struct{})
+ for _, item := range input {
+ itemMap, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+ itemType, _ := itemMap["type"].(string)
+ switch itemType {
+ case "function_call_output":
+ callID, _ := itemMap["call_id"].(string)
+ callID = strings.TrimSpace(callID)
+ if callID == "" {
+ result.HasFunctionCallOutputMissingCallID = true
+ continue
+ }
+ callIDs[callID] = struct{}{}
+ case "item_reference":
+ idValue, _ := itemMap["id"].(string)
+ idValue = strings.TrimSpace(idValue)
+ if idValue == "" {
+ continue
+ }
+ referenceIDs[idValue] = struct{}{}
+ }
+ }
+
+ if len(callIDs) == 0 || len(referenceIDs) == 0 {
+ return result
+ }
+ allReferenced := true
+ for callID := range callIDs {
+ if _, ok := referenceIDs[callID]; !ok {
+ allReferenced = false
+ break
+ }
+ }
+ result.HasItemReferenceForAllCallIDs = allReferenced
+ return result
+}
+
// HasFunctionCallOutput 判断 input 是否包含 function_call_output,用于触发续链校验。
func HasFunctionCallOutput(reqBody map[string]any) bool {
- if reqBody == nil {
- return false
- }
- return inputHasType(reqBody, "function_call_output")
+ return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutput
}
// HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call,
// 用于判断 function_call_output 是否具备可关联的上下文。
func HasToolCallContext(reqBody map[string]any) bool {
- if reqBody == nil {
- return false
- }
- input, ok := reqBody["input"].([]any)
- if !ok {
- return false
- }
- for _, item := range input {
- itemMap, ok := item.(map[string]any)
- if !ok {
- continue
- }
- itemType, _ := itemMap["type"].(string)
- if itemType != "tool_call" && itemType != "function_call" {
- continue
- }
- if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
- return true
- }
- }
- return false
+ return AnalyzeToolContinuationSignals(reqBody).HasToolCallContext
}
// FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。
// 仅返回非空 call_id,用于与 item_reference.id 做匹配校验。
func FunctionCallOutputCallIDs(reqBody map[string]any) []string {
- if reqBody == nil {
- return nil
- }
- input, ok := reqBody["input"].([]any)
- if !ok {
- return nil
- }
- ids := make(map[string]struct{})
- for _, item := range input {
- itemMap, ok := item.(map[string]any)
- if !ok {
- continue
- }
- itemType, _ := itemMap["type"].(string)
- if itemType != "function_call_output" {
- continue
- }
- if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
- ids[callID] = struct{}{}
- }
- }
- if len(ids) == 0 {
- return nil
- }
- result := make([]string, 0, len(ids))
- for id := range ids {
- result = append(result, id)
- }
- return result
+ return AnalyzeToolContinuationSignals(reqBody).FunctionCallOutputCallIDs
}
// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。
func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
- if reqBody == nil {
- return false
- }
- input, ok := reqBody["input"].([]any)
- if !ok {
- return false
- }
- for _, item := range input {
- itemMap, ok := item.(map[string]any)
- if !ok {
- continue
- }
- itemType, _ := itemMap["type"].(string)
- if itemType != "function_call_output" {
- continue
- }
- callID, _ := itemMap["call_id"].(string)
- if strings.TrimSpace(callID) == "" {
- return true
- }
- }
- return false
+ return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutputMissingCallID
}
// HasItemReferenceForCallIDs 判断 item_reference.id 是否覆盖所有 call_id。
@@ -152,32 +254,13 @@ func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
return false
}
for _, callID := range callIDs {
- if _, ok := referenceIDs[callID]; !ok {
+ if _, ok := referenceIDs[strings.TrimSpace(callID)]; !ok {
return false
}
}
return true
}
-// inputHasType 判断 input 中是否存在指定类型的 item。
-func inputHasType(reqBody map[string]any, want string) bool {
- input, ok := reqBody["input"].([]any)
- if !ok {
- return false
- }
- for _, item := range input {
- itemMap, ok := item.(map[string]any)
- if !ok {
- continue
- }
- itemType, _ := itemMap["type"].(string)
- if itemType == want {
- return true
- }
- }
- return false
-}
-
// hasNonEmptyString 判断字段是否为非空字符串。
func hasNonEmptyString(value any) bool {
stringValue, ok := value.(string)
diff --git a/backend/internal/service/openai_tool_corrector.go b/backend/internal/service/openai_tool_corrector.go
index deec80fa..348723a6 100644
--- a/backend/internal/service/openai_tool_corrector.go
+++ b/backend/internal/service/openai_tool_corrector.go
@@ -1,11 +1,15 @@
package service
import (
- "encoding/json"
+ "bytes"
"fmt"
+ "strconv"
+ "strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
)
// codexToolNameMapping 定义 Codex 原生工具名称到 OpenCode 工具名称的映射
@@ -62,169 +66,201 @@ func (c *CodexToolCorrector) CorrectToolCallsInSSEData(data string) (string, boo
if data == "" || data == "\n" {
return data, false
}
+ correctedBytes, corrected := c.CorrectToolCallsInSSEBytes([]byte(data))
+ if !corrected {
+ return data, false
+ }
+ return string(correctedBytes), true
+}
- // 尝试解析 JSON
- var payload map[string]any
- if err := json.Unmarshal([]byte(data), &payload); err != nil {
- // 不是有效的 JSON,直接返回原数据
+// CorrectToolCallsInSSEBytes 修正 SSE JSON 数据中的工具调用(字节路径)。
+// 返回修正后的数据和是否进行了修正。
+func (c *CodexToolCorrector) CorrectToolCallsInSSEBytes(data []byte) ([]byte, bool) {
+ if len(bytes.TrimSpace(data)) == 0 {
+ return data, false
+ }
+ if !mayContainToolCallPayload(data) {
+ return data, false
+ }
+ if !gjson.ValidBytes(data) {
+ // 不是有效 JSON,直接返回原数据
return data, false
}
+ updated := data
corrected := false
-
- // 处理 tool_calls 数组
- if toolCalls, ok := payload["tool_calls"].([]any); ok {
- if c.correctToolCallsArray(toolCalls) {
+ collect := func(changed bool, next []byte) {
+ if changed {
corrected = true
+ updated = next
}
}
- // 处理 function_call 对象
- if functionCall, ok := payload["function_call"].(map[string]any); ok {
- if c.correctFunctionCall(functionCall) {
- corrected = true
- }
+ if next, changed := c.correctToolCallsArrayAtPath(updated, "tool_calls"); changed {
+ collect(changed, next)
+ }
+ if next, changed := c.correctFunctionAtPath(updated, "function_call"); changed {
+ collect(changed, next)
+ }
+ if next, changed := c.correctToolCallsArrayAtPath(updated, "delta.tool_calls"); changed {
+ collect(changed, next)
+ }
+ if next, changed := c.correctFunctionAtPath(updated, "delta.function_call"); changed {
+ collect(changed, next)
}
- // 处理 delta.tool_calls
- if delta, ok := payload["delta"].(map[string]any); ok {
- if toolCalls, ok := delta["tool_calls"].([]any); ok {
- if c.correctToolCallsArray(toolCalls) {
- corrected = true
- }
+ choicesCount := int(gjson.GetBytes(updated, "choices.#").Int())
+ for i := 0; i < choicesCount; i++ {
+ prefix := "choices." + strconv.Itoa(i)
+ if next, changed := c.correctToolCallsArrayAtPath(updated, prefix+".message.tool_calls"); changed {
+ collect(changed, next)
}
- if functionCall, ok := delta["function_call"].(map[string]any); ok {
- if c.correctFunctionCall(functionCall) {
- corrected = true
- }
+ if next, changed := c.correctFunctionAtPath(updated, prefix+".message.function_call"); changed {
+ collect(changed, next)
}
- }
-
- // 处理 choices[].message.tool_calls 和 choices[].delta.tool_calls
- if choices, ok := payload["choices"].([]any); ok {
- for _, choice := range choices {
- if choiceMap, ok := choice.(map[string]any); ok {
- // 处理 message 中的工具调用
- if message, ok := choiceMap["message"].(map[string]any); ok {
- if toolCalls, ok := message["tool_calls"].([]any); ok {
- if c.correctToolCallsArray(toolCalls) {
- corrected = true
- }
- }
- if functionCall, ok := message["function_call"].(map[string]any); ok {
- if c.correctFunctionCall(functionCall) {
- corrected = true
- }
- }
- }
- // 处理 delta 中的工具调用
- if delta, ok := choiceMap["delta"].(map[string]any); ok {
- if toolCalls, ok := delta["tool_calls"].([]any); ok {
- if c.correctToolCallsArray(toolCalls) {
- corrected = true
- }
- }
- if functionCall, ok := delta["function_call"].(map[string]any); ok {
- if c.correctFunctionCall(functionCall) {
- corrected = true
- }
- }
- }
- }
+ if next, changed := c.correctToolCallsArrayAtPath(updated, prefix+".delta.tool_calls"); changed {
+ collect(changed, next)
+ }
+ if next, changed := c.correctFunctionAtPath(updated, prefix+".delta.function_call"); changed {
+ collect(changed, next)
}
}
if !corrected {
return data, false
}
+ return updated, true
+}
- // 序列化回 JSON
- correctedBytes, err := json.Marshal(payload)
- if err != nil {
- logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Failed to marshal corrected data: %v", err)
+func mayContainToolCallPayload(data []byte) bool {
+ // 快速路径:多数 token / 文本事件不包含工具字段,避免进入 JSON 解析热路径。
+ return bytes.Contains(data, []byte(`"tool_calls"`)) ||
+ bytes.Contains(data, []byte(`"function_call"`)) ||
+ bytes.Contains(data, []byte(`"function":{"name"`))
+}
+
+// correctToolCallsArrayAtPath 修正指定路径下 tool_calls 数组中的工具名称。
+func (c *CodexToolCorrector) correctToolCallsArrayAtPath(data []byte, toolCallsPath string) ([]byte, bool) {
+ count := int(gjson.GetBytes(data, toolCallsPath+".#").Int())
+ if count <= 0 {
return data, false
}
-
- return string(correctedBytes), true
-}
-
-// correctToolCallsArray 修正工具调用数组中的工具名称
-func (c *CodexToolCorrector) correctToolCallsArray(toolCalls []any) bool {
+ updated := data
corrected := false
- for _, toolCall := range toolCalls {
- if toolCallMap, ok := toolCall.(map[string]any); ok {
- if function, ok := toolCallMap["function"].(map[string]any); ok {
- if c.correctFunctionCall(function) {
- corrected = true
- }
- }
+ for i := 0; i < count; i++ {
+ functionPath := toolCallsPath + "." + strconv.Itoa(i) + ".function"
+ if next, changed := c.correctFunctionAtPath(updated, functionPath); changed {
+ updated = next
+ corrected = true
}
}
- return corrected
+ return updated, corrected
}
-// correctFunctionCall 修正单个函数调用的工具名称和参数
-func (c *CodexToolCorrector) correctFunctionCall(functionCall map[string]any) bool {
- name, ok := functionCall["name"].(string)
- if !ok || name == "" {
- return false
+// correctFunctionAtPath 修正指定路径下单个函数调用的工具名称和参数。
+func (c *CodexToolCorrector) correctFunctionAtPath(data []byte, functionPath string) ([]byte, bool) {
+ namePath := functionPath + ".name"
+ nameResult := gjson.GetBytes(data, namePath)
+ if !nameResult.Exists() || nameResult.Type != gjson.String {
+ return data, false
}
-
+ name := strings.TrimSpace(nameResult.Str)
+ if name == "" {
+ return data, false
+ }
+ updated := data
corrected := false
// 查找并修正工具名称
if correctName, found := codexToolNameMapping[name]; found {
- functionCall["name"] = correctName
- c.recordCorrection(name, correctName)
- corrected = true
- name = correctName // 使用修正后的名称进行参数修正
+ if next, err := sjson.SetBytes(updated, namePath, correctName); err == nil {
+ updated = next
+ c.recordCorrection(name, correctName)
+ corrected = true
+ name = correctName // 使用修正后的名称进行参数修正
+ }
}
// 修正工具参数(基于工具名称)
- if c.correctToolParameters(name, functionCall) {
+ if next, changed := c.correctToolParametersAtPath(updated, functionPath+".arguments", name); changed {
+ updated = next
corrected = true
}
-
- return corrected
+ return updated, corrected
}
-// correctToolParameters 修正工具参数以符合 OpenCode 规范
-func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall map[string]any) bool {
- arguments, ok := functionCall["arguments"]
- if !ok {
- return false
+// correctToolParametersAtPath 修正指定路径下 arguments 参数。
+func (c *CodexToolCorrector) correctToolParametersAtPath(data []byte, argumentsPath, toolName string) ([]byte, bool) {
+ if toolName != "bash" && toolName != "edit" {
+ return data, false
}
- // arguments 可能是字符串(JSON)或已解析的 map
- var argsMap map[string]any
- switch v := arguments.(type) {
- case string:
- // 解析 JSON 字符串
- if err := json.Unmarshal([]byte(v), &argsMap); err != nil {
- return false
+ args := gjson.GetBytes(data, argumentsPath)
+ if !args.Exists() {
+ return data, false
+ }
+
+ switch args.Type {
+ case gjson.String:
+ argsJSON := strings.TrimSpace(args.Str)
+ if !gjson.Valid(argsJSON) {
+ return data, false
}
- case map[string]any:
- argsMap = v
+ if !gjson.Parse(argsJSON).IsObject() {
+ return data, false
+ }
+ nextArgsJSON, corrected := c.correctToolArgumentsJSON(argsJSON, toolName)
+ if !corrected {
+ return data, false
+ }
+ next, err := sjson.SetBytes(data, argumentsPath, nextArgsJSON)
+ if err != nil {
+ return data, false
+ }
+ return next, true
+ case gjson.JSON:
+ if !args.IsObject() || !gjson.Valid(args.Raw) {
+ return data, false
+ }
+ nextArgsJSON, corrected := c.correctToolArgumentsJSON(args.Raw, toolName)
+ if !corrected {
+ return data, false
+ }
+ next, err := sjson.SetRawBytes(data, argumentsPath, []byte(nextArgsJSON))
+ if err != nil {
+ return data, false
+ }
+ return next, true
default:
- return false
+ return data, false
+ }
+}
+
+// correctToolArgumentsJSON 修正工具参数 JSON(对象字符串),返回修正后的 JSON 与是否变更。
+func (c *CodexToolCorrector) correctToolArgumentsJSON(argsJSON, toolName string) (string, bool) {
+ if !gjson.Valid(argsJSON) {
+ return argsJSON, false
+ }
+ if !gjson.Parse(argsJSON).IsObject() {
+ return argsJSON, false
}
+ updated := argsJSON
corrected := false
// 根据工具名称应用特定的参数修正规则
switch toolName {
case "bash":
// OpenCode bash 支持 workdir;有些来源会输出 work_dir。
- if _, hasWorkdir := argsMap["workdir"]; !hasWorkdir {
- if workDir, exists := argsMap["work_dir"]; exists {
- argsMap["workdir"] = workDir
- delete(argsMap, "work_dir")
+ if !gjson.Get(updated, "workdir").Exists() {
+ if next, changed := moveJSONField(updated, "work_dir", "workdir"); changed {
+ updated = next
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool")
}
} else {
- if _, exists := argsMap["work_dir"]; exists {
- delete(argsMap, "work_dir")
+ if next, changed := deleteJSONField(updated, "work_dir"); changed {
+ updated = next
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool")
}
@@ -232,67 +268,71 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall
case "edit":
// OpenCode edit 参数为 filePath/oldString/newString(camelCase)。
- if _, exists := argsMap["filePath"]; !exists {
- if filePath, exists := argsMap["file_path"]; exists {
- argsMap["filePath"] = filePath
- delete(argsMap, "file_path")
+ if !gjson.Get(updated, "filePath").Exists() {
+ if next, changed := moveJSONField(updated, "file_path", "filePath"); changed {
+ updated = next
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool")
- } else if filePath, exists := argsMap["path"]; exists {
- argsMap["filePath"] = filePath
- delete(argsMap, "path")
+ } else if next, changed := moveJSONField(updated, "path", "filePath"); changed {
+ updated = next
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool")
- } else if filePath, exists := argsMap["file"]; exists {
- argsMap["filePath"] = filePath
- delete(argsMap, "file")
+ } else if next, changed := moveJSONField(updated, "file", "filePath"); changed {
+ updated = next
corrected = true
logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool")
}
}
- if _, exists := argsMap["oldString"]; !exists {
- if oldString, exists := argsMap["old_string"]; exists {
- argsMap["oldString"] = oldString
- delete(argsMap, "old_string")
- corrected = true
- logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool")
- }
+ if next, changed := moveJSONField(updated, "old_string", "oldString"); changed {
+ updated = next
+ corrected = true
+ logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool")
}
- if _, exists := argsMap["newString"]; !exists {
- if newString, exists := argsMap["new_string"]; exists {
- argsMap["newString"] = newString
- delete(argsMap, "new_string")
- corrected = true
- logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool")
- }
+ if next, changed := moveJSONField(updated, "new_string", "newString"); changed {
+ updated = next
+ corrected = true
+ logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool")
}
- if _, exists := argsMap["replaceAll"]; !exists {
- if replaceAll, exists := argsMap["replace_all"]; exists {
- argsMap["replaceAll"] = replaceAll
- delete(argsMap, "replace_all")
- corrected = true
- logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool")
- }
+ if next, changed := moveJSONField(updated, "replace_all", "replaceAll"); changed {
+ updated = next
+ corrected = true
+ logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool")
}
}
+ return updated, corrected
+}
- // 如果修正了参数,需要重新序列化
- if corrected {
- if _, wasString := arguments.(string); wasString {
- // 原本是字符串,序列化回字符串
- if newArgsJSON, err := json.Marshal(argsMap); err == nil {
- functionCall["arguments"] = string(newArgsJSON)
- }
- } else {
- // 原本是 map,直接赋值
- functionCall["arguments"] = argsMap
- }
+func moveJSONField(input, from, to string) (string, bool) {
+ if gjson.Get(input, to).Exists() {
+ return input, false
}
+ src := gjson.Get(input, from)
+ if !src.Exists() {
+ return input, false
+ }
+ next, err := sjson.SetRaw(input, to, src.Raw)
+ if err != nil {
+ return input, false
+ }
+ next, err = sjson.Delete(next, from)
+ if err != nil {
+ return input, false
+ }
+ return next, true
+}
- return corrected
+func deleteJSONField(input, path string) (string, bool) {
+ if !gjson.Get(input, path).Exists() {
+ return input, false
+ }
+ next, err := sjson.Delete(input, path)
+ if err != nil {
+ return input, false
+ }
+ return next, true
}
// recordCorrection 记录一次工具名称修正
diff --git a/backend/internal/service/openai_tool_corrector_test.go b/backend/internal/service/openai_tool_corrector_test.go
index ff518ea6..7c83de9e 100644
--- a/backend/internal/service/openai_tool_corrector_test.go
+++ b/backend/internal/service/openai_tool_corrector_test.go
@@ -5,6 +5,15 @@ import (
"testing"
)
+func TestMayContainToolCallPayload(t *testing.T) {
+ if mayContainToolCallPayload([]byte(`{"type":"response.output_text.delta","delta":"hello"}`)) {
+ t.Fatalf("plain text event should not trigger tool-call parsing")
+ }
+ if !mayContainToolCallPayload([]byte(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)) {
+ t.Fatalf("tool_calls event should trigger tool-call parsing")
+ }
+}
+
func TestCorrectToolCallsInSSEData(t *testing.T) {
corrector := NewCodexToolCorrector()
diff --git a/backend/internal/service/openai_ws_account_sticky_test.go b/backend/internal/service/openai_ws_account_sticky_test.go
new file mode 100644
index 00000000..3fe08179
--- /dev/null
+++ b/backend/internal/service/openai_ws_account_sticky_test.go
@@ -0,0 +1,190 @@
+package service
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(23)
+ account := Account{
+ ID: 2,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 2,
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_enabled": true,
+ },
+ }
+ cache := &stubGatewayCache{}
+ store := NewOpenAIWSStateStore(cache)
+ cfg := newOpenAIWSV2TestConfig()
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ openaiWSStateStore: store,
+ }
+
+ require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_1", account.ID, time.Hour))
+
+ selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_1", "gpt-5.1", nil)
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, account.ID, selection.Account.ID)
+ require.True(t, selection.Acquired)
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+}
+
+func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(23)
+ account := Account{
+ ID: 8,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_enabled": true,
+ },
+ }
+ cache := &stubGatewayCache{}
+ store := NewOpenAIWSStateStore(cache)
+ cfg := newOpenAIWSV2TestConfig()
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ openaiWSStateStore: store,
+ }
+
+ require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_2", account.ID, time.Hour))
+
+ selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_2", "gpt-5.1", map[int64]struct{}{account.ID: {}})
+ require.NoError(t, err)
+ require.Nil(t, selection)
+}
+
+func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_ForceHTTPIgnored(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(23)
+ account := Account{
+ ID: 11,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Extra: map[string]any{
+ "openai_ws_force_http": true,
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+ cache := &stubGatewayCache{}
+ store := NewOpenAIWSStateStore(cache)
+ cfg := newOpenAIWSV2TestConfig()
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ openaiWSStateStore: store,
+ }
+
+ require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_force_http", account.ID, time.Hour))
+
+ selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_force_http", "gpt-5.1", nil)
+ require.NoError(t, err)
+ require.Nil(t, selection, "force_http 场景应忽略 previous_response_id 粘连")
+}
+
+func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_BusyKeepsSticky(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(23)
+ accounts := []Account{
+ {
+ ID: 21,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 0,
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_enabled": true,
+ },
+ },
+ {
+ ID: 22,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Priority: 9,
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_enabled": true,
+ },
+ },
+ }
+
+ cache := &stubGatewayCache{}
+ store := NewOpenAIWSStateStore(cache)
+ cfg := newOpenAIWSV2TestConfig()
+ cfg.Gateway.Scheduling.StickySessionMaxWaiting = 2
+ cfg.Gateway.Scheduling.StickySessionWaitTimeout = 30 * time.Second
+
+ concurrencyCache := stubConcurrencyCache{
+ acquireResults: map[int64]bool{
+ 21: false, // previous_response 命中的账号繁忙
+ 22: true, // 次优账号可用(若回退会命中)
+ },
+ waitCounts: map[int64]int{
+ 21: 999,
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ openaiWSStateStore: store,
+ }
+
+ require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_busy", 21, time.Hour))
+
+ selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_busy", "gpt-5.1", nil)
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(21), selection.Account.ID, "busy previous_response sticky account should remain selected")
+ require.False(t, selection.Acquired)
+ require.NotNil(t, selection.WaitPlan)
+ require.Equal(t, int64(21), selection.WaitPlan.AccountID)
+}
+
+func newOpenAIWSV2TestConfig() *config.Config {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
+ return cfg
+}
diff --git a/backend/internal/service/openai_ws_client.go b/backend/internal/service/openai_ws_client.go
new file mode 100644
index 00000000..9f3c47b7
--- /dev/null
+++ b/backend/internal/service/openai_ws_client.go
@@ -0,0 +1,285 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/http"
+ "net/url"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ coderws "github.com/coder/websocket"
+ "github.com/coder/websocket/wsjson"
+)
+
+const openAIWSMessageReadLimitBytes int64 = 16 * 1024 * 1024
+const (
+ openAIWSProxyTransportMaxIdleConns = 128
+ openAIWSProxyTransportMaxIdleConnsPerHost = 64
+ openAIWSProxyTransportIdleConnTimeout = 90 * time.Second
+ openAIWSProxyClientCacheMaxEntries = 256
+ openAIWSProxyClientCacheIdleTTL = 15 * time.Minute
+)
+
+type OpenAIWSTransportMetricsSnapshot struct {
+ ProxyClientCacheHits int64 `json:"proxy_client_cache_hits"`
+ ProxyClientCacheMisses int64 `json:"proxy_client_cache_misses"`
+ TransportReuseRatio float64 `json:"transport_reuse_ratio"`
+}
+
+// openAIWSClientConn 抽象 WS 客户端连接,便于替换底层实现。
+type openAIWSClientConn interface {
+ WriteJSON(ctx context.Context, value any) error
+ ReadMessage(ctx context.Context) ([]byte, error)
+ Ping(ctx context.Context) error
+ Close() error
+}
+
+// openAIWSClientDialer 抽象 WS 建连器。
+type openAIWSClientDialer interface {
+ Dial(ctx context.Context, wsURL string, headers http.Header, proxyURL string) (openAIWSClientConn, int, http.Header, error)
+}
+
+type openAIWSTransportMetricsDialer interface {
+ SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot
+}
+
+func newDefaultOpenAIWSClientDialer() openAIWSClientDialer {
+ return &coderOpenAIWSClientDialer{
+ proxyClients: make(map[string]*openAIWSProxyClientEntry),
+ }
+}
+
+type coderOpenAIWSClientDialer struct {
+ proxyMu sync.Mutex
+ proxyClients map[string]*openAIWSProxyClientEntry
+ proxyHits atomic.Int64
+ proxyMisses atomic.Int64
+}
+
+type openAIWSProxyClientEntry struct {
+ client *http.Client
+ lastUsedUnixNano int64
+}
+
+func (d *coderOpenAIWSClientDialer) Dial(
+ ctx context.Context,
+ wsURL string,
+ headers http.Header,
+ proxyURL string,
+) (openAIWSClientConn, int, http.Header, error) {
+ targetURL := strings.TrimSpace(wsURL)
+ if targetURL == "" {
+ return nil, 0, nil, errors.New("ws url is empty")
+ }
+
+ opts := &coderws.DialOptions{
+ HTTPHeader: cloneHeader(headers),
+ CompressionMode: coderws.CompressionContextTakeover,
+ }
+ if proxy := strings.TrimSpace(proxyURL); proxy != "" {
+ proxyClient, err := d.proxyHTTPClient(proxy)
+ if err != nil {
+ return nil, 0, nil, err
+ }
+ opts.HTTPClient = proxyClient
+ }
+
+ conn, resp, err := coderws.Dial(ctx, targetURL, opts)
+ if err != nil {
+ status := 0
+ respHeaders := http.Header(nil)
+ if resp != nil {
+ status = resp.StatusCode
+ respHeaders = cloneHeader(resp.Header)
+ }
+ return nil, status, respHeaders, err
+ }
+ // coder/websocket 默认单消息读取上限为 32KB,Codex WS 事件(如 rate_limits/大 delta)
+ // 可能超过该阈值,需显式提高上限,避免本地 read_fail(message too big)。
+ conn.SetReadLimit(openAIWSMessageReadLimitBytes)
+ respHeaders := http.Header(nil)
+ if resp != nil {
+ respHeaders = cloneHeader(resp.Header)
+ }
+ return &coderOpenAIWSClientConn{conn: conn}, 0, respHeaders, nil
+}
+
+func (d *coderOpenAIWSClientDialer) proxyHTTPClient(proxy string) (*http.Client, error) {
+ if d == nil {
+ return nil, errors.New("openai ws dialer is nil")
+ }
+ normalizedProxy := strings.TrimSpace(proxy)
+ if normalizedProxy == "" {
+ return nil, errors.New("proxy url is empty")
+ }
+ parsedProxyURL, err := url.Parse(normalizedProxy)
+ if err != nil {
+ return nil, fmt.Errorf("invalid proxy url: %w", err)
+ }
+ now := time.Now().UnixNano()
+
+ d.proxyMu.Lock()
+ defer d.proxyMu.Unlock()
+ if entry, ok := d.proxyClients[normalizedProxy]; ok && entry != nil && entry.client != nil {
+ entry.lastUsedUnixNano = now
+ d.proxyHits.Add(1)
+ return entry.client, nil
+ }
+ d.cleanupProxyClientsLocked(now)
+ transport := &http.Transport{
+ Proxy: http.ProxyURL(parsedProxyURL),
+ MaxIdleConns: openAIWSProxyTransportMaxIdleConns,
+ MaxIdleConnsPerHost: openAIWSProxyTransportMaxIdleConnsPerHost,
+ IdleConnTimeout: openAIWSProxyTransportIdleConnTimeout,
+ TLSHandshakeTimeout: 10 * time.Second,
+ ForceAttemptHTTP2: true,
+ }
+ client := &http.Client{Transport: transport}
+ d.proxyClients[normalizedProxy] = &openAIWSProxyClientEntry{
+ client: client,
+ lastUsedUnixNano: now,
+ }
+ d.ensureProxyClientCapacityLocked()
+ d.proxyMisses.Add(1)
+ return client, nil
+}
+
+func (d *coderOpenAIWSClientDialer) cleanupProxyClientsLocked(nowUnixNano int64) {
+ if d == nil || len(d.proxyClients) == 0 {
+ return
+ }
+ idleTTL := openAIWSProxyClientCacheIdleTTL
+ if idleTTL <= 0 {
+ return
+ }
+ now := time.Unix(0, nowUnixNano)
+ for key, entry := range d.proxyClients {
+ if entry == nil || entry.client == nil {
+ delete(d.proxyClients, key)
+ continue
+ }
+ lastUsed := time.Unix(0, entry.lastUsedUnixNano)
+ if now.Sub(lastUsed) > idleTTL {
+ closeOpenAIWSProxyClient(entry.client)
+ delete(d.proxyClients, key)
+ }
+ }
+}
+
+func (d *coderOpenAIWSClientDialer) ensureProxyClientCapacityLocked() {
+ if d == nil {
+ return
+ }
+ maxEntries := openAIWSProxyClientCacheMaxEntries
+ if maxEntries <= 0 {
+ return
+ }
+ for len(d.proxyClients) > maxEntries {
+ var oldestKey string
+ var oldestLastUsed int64
+ hasOldest := false
+ for key, entry := range d.proxyClients {
+ lastUsed := int64(0)
+ if entry != nil {
+ lastUsed = entry.lastUsedUnixNano
+ }
+ if !hasOldest || lastUsed < oldestLastUsed {
+ hasOldest = true
+ oldestKey = key
+ oldestLastUsed = lastUsed
+ }
+ }
+ if !hasOldest {
+ return
+ }
+ if entry := d.proxyClients[oldestKey]; entry != nil {
+ closeOpenAIWSProxyClient(entry.client)
+ }
+ delete(d.proxyClients, oldestKey)
+ }
+}
+
+func closeOpenAIWSProxyClient(client *http.Client) {
+ if client == nil || client.Transport == nil {
+ return
+ }
+ if transport, ok := client.Transport.(*http.Transport); ok && transport != nil {
+ transport.CloseIdleConnections()
+ }
+}
+
+func (d *coderOpenAIWSClientDialer) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot {
+ if d == nil {
+ return OpenAIWSTransportMetricsSnapshot{}
+ }
+ hits := d.proxyHits.Load()
+ misses := d.proxyMisses.Load()
+ total := hits + misses
+ reuseRatio := 0.0
+ if total > 0 {
+ reuseRatio = float64(hits) / float64(total)
+ }
+ return OpenAIWSTransportMetricsSnapshot{
+ ProxyClientCacheHits: hits,
+ ProxyClientCacheMisses: misses,
+ TransportReuseRatio: reuseRatio,
+ }
+}
+
+type coderOpenAIWSClientConn struct {
+ conn *coderws.Conn
+}
+
+func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error {
+ if c == nil || c.conn == nil {
+ return errOpenAIWSConnClosed
+ }
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ return wsjson.Write(ctx, c.conn, value)
+}
+
+func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, error) {
+ if c == nil || c.conn == nil {
+ return nil, errOpenAIWSConnClosed
+ }
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ msgType, payload, err := c.conn.Read(ctx)
+ if err != nil {
+ return nil, err
+ }
+ switch msgType {
+ case coderws.MessageText, coderws.MessageBinary:
+ return payload, nil
+ default:
+ return nil, errOpenAIWSConnClosed
+ }
+}
+
+func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error {
+ if c == nil || c.conn == nil {
+ return errOpenAIWSConnClosed
+ }
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ return c.conn.Ping(ctx)
+}
+
+func (c *coderOpenAIWSClientConn) Close() error {
+ if c == nil || c.conn == nil {
+ return nil
+ }
+ // Close 为幂等,忽略重复关闭错误。
+ _ = c.conn.Close(coderws.StatusNormalClosure, "")
+ _ = c.conn.CloseNow()
+ return nil
+}
diff --git a/backend/internal/service/openai_ws_client_test.go b/backend/internal/service/openai_ws_client_test.go
new file mode 100644
index 00000000..a88d6266
--- /dev/null
+++ b/backend/internal/service/openai_ws_client_test.go
@@ -0,0 +1,112 @@
+package service
+
+import (
+ "fmt"
+ "net/http"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestCoderOpenAIWSClientDialer_ProxyHTTPClientReuse(t *testing.T) {
+ dialer := newDefaultOpenAIWSClientDialer()
+ impl, ok := dialer.(*coderOpenAIWSClientDialer)
+ require.True(t, ok)
+
+ c1, err := impl.proxyHTTPClient("http://127.0.0.1:8080")
+ require.NoError(t, err)
+ c2, err := impl.proxyHTTPClient("http://127.0.0.1:8080")
+ require.NoError(t, err)
+ require.Same(t, c1, c2, "同一代理地址应复用同一个 HTTP 客户端")
+
+ c3, err := impl.proxyHTTPClient("http://127.0.0.1:8081")
+ require.NoError(t, err)
+ require.NotSame(t, c1, c3, "不同代理地址应分离客户端")
+}
+
+func TestCoderOpenAIWSClientDialer_ProxyHTTPClientInvalidURL(t *testing.T) {
+ dialer := newDefaultOpenAIWSClientDialer()
+ impl, ok := dialer.(*coderOpenAIWSClientDialer)
+ require.True(t, ok)
+
+ _, err := impl.proxyHTTPClient("://bad")
+ require.Error(t, err)
+}
+
+func TestCoderOpenAIWSClientDialer_TransportMetricsSnapshot(t *testing.T) {
+ dialer := newDefaultOpenAIWSClientDialer()
+ impl, ok := dialer.(*coderOpenAIWSClientDialer)
+ require.True(t, ok)
+
+ _, err := impl.proxyHTTPClient("http://127.0.0.1:18080")
+ require.NoError(t, err)
+ _, err = impl.proxyHTTPClient("http://127.0.0.1:18080")
+ require.NoError(t, err)
+ _, err = impl.proxyHTTPClient("http://127.0.0.1:18081")
+ require.NoError(t, err)
+
+ snapshot := impl.SnapshotTransportMetrics()
+ require.Equal(t, int64(1), snapshot.ProxyClientCacheHits)
+ require.Equal(t, int64(2), snapshot.ProxyClientCacheMisses)
+ require.InDelta(t, 1.0/3.0, snapshot.TransportReuseRatio, 0.0001)
+}
+
+func TestCoderOpenAIWSClientDialer_ProxyClientCacheCapacity(t *testing.T) {
+ dialer := newDefaultOpenAIWSClientDialer()
+ impl, ok := dialer.(*coderOpenAIWSClientDialer)
+ require.True(t, ok)
+
+ total := openAIWSProxyClientCacheMaxEntries + 32
+ for i := 0; i < total; i++ {
+ _, err := impl.proxyHTTPClient(fmt.Sprintf("http://127.0.0.1:%d", 20000+i))
+ require.NoError(t, err)
+ }
+
+ impl.proxyMu.Lock()
+ cacheSize := len(impl.proxyClients)
+ impl.proxyMu.Unlock()
+
+ require.LessOrEqual(t, cacheSize, openAIWSProxyClientCacheMaxEntries, "代理客户端缓存应受容量上限约束")
+}
+
+func TestCoderOpenAIWSClientDialer_ProxyClientCacheIdleTTL(t *testing.T) {
+ dialer := newDefaultOpenAIWSClientDialer()
+ impl, ok := dialer.(*coderOpenAIWSClientDialer)
+ require.True(t, ok)
+
+ oldProxy := "http://127.0.0.1:28080"
+ _, err := impl.proxyHTTPClient(oldProxy)
+ require.NoError(t, err)
+
+ impl.proxyMu.Lock()
+ oldEntry := impl.proxyClients[oldProxy]
+ require.NotNil(t, oldEntry)
+ oldEntry.lastUsedUnixNano = time.Now().Add(-openAIWSProxyClientCacheIdleTTL - time.Minute).UnixNano()
+ impl.proxyMu.Unlock()
+
+ // 触发一次新的代理获取,驱动 TTL 清理。
+ _, err = impl.proxyHTTPClient("http://127.0.0.1:28081")
+ require.NoError(t, err)
+
+ impl.proxyMu.Lock()
+ _, exists := impl.proxyClients[oldProxy]
+ impl.proxyMu.Unlock()
+
+ require.False(t, exists, "超过空闲 TTL 的代理客户端应被回收")
+}
+
+func TestCoderOpenAIWSClientDialer_ProxyTransportTLSHandshakeTimeout(t *testing.T) {
+ dialer := newDefaultOpenAIWSClientDialer()
+ impl, ok := dialer.(*coderOpenAIWSClientDialer)
+ require.True(t, ok)
+
+ client, err := impl.proxyHTTPClient("http://127.0.0.1:38080")
+ require.NoError(t, err)
+ require.NotNil(t, client)
+
+ transport, ok := client.Transport.(*http.Transport)
+ require.True(t, ok)
+ require.NotNil(t, transport)
+ require.Equal(t, 10*time.Second, transport.TLSHandshakeTimeout)
+}
diff --git a/backend/internal/service/openai_ws_fallback_test.go b/backend/internal/service/openai_ws_fallback_test.go
new file mode 100644
index 00000000..ce06f6a2
--- /dev/null
+++ b/backend/internal/service/openai_ws_fallback_test.go
@@ -0,0 +1,251 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ coderws "github.com/coder/websocket"
+ "github.com/stretchr/testify/require"
+)
+
+func TestClassifyOpenAIWSAcquireError(t *testing.T) {
+ t.Run("dial_426_upgrade_required", func(t *testing.T) {
+ err := &openAIWSDialError{StatusCode: 426, Err: errors.New("upgrade required")}
+ require.Equal(t, "upgrade_required", classifyOpenAIWSAcquireError(err))
+ })
+
+ t.Run("queue_full", func(t *testing.T) {
+ require.Equal(t, "conn_queue_full", classifyOpenAIWSAcquireError(errOpenAIWSConnQueueFull))
+ })
+
+ t.Run("preferred_conn_unavailable", func(t *testing.T) {
+ require.Equal(t, "preferred_conn_unavailable", classifyOpenAIWSAcquireError(errOpenAIWSPreferredConnUnavailable))
+ })
+
+ t.Run("acquire_timeout", func(t *testing.T) {
+ require.Equal(t, "acquire_timeout", classifyOpenAIWSAcquireError(context.DeadlineExceeded))
+ })
+
+ t.Run("auth_failed_401", func(t *testing.T) {
+ err := &openAIWSDialError{StatusCode: 401, Err: errors.New("unauthorized")}
+ require.Equal(t, "auth_failed", classifyOpenAIWSAcquireError(err))
+ })
+
+ t.Run("upstream_rate_limited", func(t *testing.T) {
+ err := &openAIWSDialError{StatusCode: 429, Err: errors.New("rate limited")}
+ require.Equal(t, "upstream_rate_limited", classifyOpenAIWSAcquireError(err))
+ })
+
+ t.Run("upstream_5xx", func(t *testing.T) {
+ err := &openAIWSDialError{StatusCode: 502, Err: errors.New("bad gateway")}
+ require.Equal(t, "upstream_5xx", classifyOpenAIWSAcquireError(err))
+ })
+
+ t.Run("dial_failed_other_status", func(t *testing.T) {
+ err := &openAIWSDialError{StatusCode: 418, Err: errors.New("teapot")}
+ require.Equal(t, "dial_failed", classifyOpenAIWSAcquireError(err))
+ })
+
+ t.Run("other", func(t *testing.T) {
+ require.Equal(t, "acquire_conn", classifyOpenAIWSAcquireError(errors.New("x")))
+ })
+
+ t.Run("nil", func(t *testing.T) {
+ require.Equal(t, "acquire_conn", classifyOpenAIWSAcquireError(nil))
+ })
+}
+
+func TestClassifyOpenAIWSDialError(t *testing.T) {
+ t.Run("handshake_not_finished", func(t *testing.T) {
+ err := &openAIWSDialError{
+ StatusCode: http.StatusBadGateway,
+ Err: errors.New("WebSocket protocol error: Handshake not finished"),
+ }
+ require.Equal(t, "handshake_not_finished", classifyOpenAIWSDialError(err))
+ })
+
+ t.Run("context_deadline", func(t *testing.T) {
+ err := &openAIWSDialError{
+ StatusCode: 0,
+ Err: context.DeadlineExceeded,
+ }
+ require.Equal(t, "ctx_deadline_exceeded", classifyOpenAIWSDialError(err))
+ })
+}
+
+func TestSummarizeOpenAIWSDialError(t *testing.T) {
+ err := &openAIWSDialError{
+ StatusCode: http.StatusBadGateway,
+ ResponseHeaders: http.Header{
+ "Server": []string{"cloudflare"},
+ "Via": []string{"1.1 example"},
+ "Cf-Ray": []string{"abcd1234"},
+ "X-Request-Id": []string{"req_123"},
+ },
+ Err: errors.New("WebSocket protocol error: Handshake not finished"),
+ }
+
+ status, class, closeStatus, closeReason, server, via, cfRay, reqID := summarizeOpenAIWSDialError(err)
+ require.Equal(t, http.StatusBadGateway, status)
+ require.Equal(t, "handshake_not_finished", class)
+ require.Equal(t, "-", closeStatus)
+ require.Equal(t, "-", closeReason)
+ require.Equal(t, "cloudflare", server)
+ require.Equal(t, "1.1 example", via)
+ require.Equal(t, "abcd1234", cfRay)
+ require.Equal(t, "req_123", reqID)
+}
+
+func TestClassifyOpenAIWSErrorEvent(t *testing.T) {
+ reason, recoverable := classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"upgrade_required","message":"Upgrade required"}}`))
+ require.Equal(t, "upgrade_required", reason)
+ require.True(t, recoverable)
+
+ reason, recoverable = classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"previous_response_not_found","message":"not found"}}`))
+ require.Equal(t, "previous_response_not_found", reason)
+ require.True(t, recoverable)
+}
+
+func TestClassifyOpenAIWSReconnectReason(t *testing.T) {
+ reason, retryable := classifyOpenAIWSReconnectReason(wrapOpenAIWSFallback("policy_violation", errors.New("policy")))
+ require.Equal(t, "policy_violation", reason)
+ require.False(t, retryable)
+
+ reason, retryable = classifyOpenAIWSReconnectReason(wrapOpenAIWSFallback("read_event", errors.New("io")))
+ require.Equal(t, "read_event", reason)
+ require.True(t, retryable)
+}
+
+func TestOpenAIWSErrorHTTPStatus(t *testing.T) {
+ require.Equal(t, http.StatusBadRequest, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)))
+ require.Equal(t, http.StatusUnauthorized, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"authentication_error","code":"invalid_api_key","message":"auth failed"}}`)))
+ require.Equal(t, http.StatusForbidden, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"permission_error","code":"forbidden","message":"forbidden"}}`)))
+ require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"rate_limit_error","code":"rate_limit_exceeded","message":"rate limited"}}`)))
+ require.Equal(t, http.StatusBadGateway, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"server_error","code":"server_error","message":"server"}}`)))
+}
+
+func TestResolveOpenAIWSFallbackErrorResponse(t *testing.T) {
+ t.Run("previous_response_not_found", func(t *testing.T) {
+ statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse(
+ wrapOpenAIWSFallback("previous_response_not_found", errors.New("previous response not found")),
+ )
+ require.True(t, ok)
+ require.Equal(t, http.StatusBadRequest, statusCode)
+ require.Equal(t, "invalid_request_error", errType)
+ require.Equal(t, "previous response not found", clientMessage)
+ require.Equal(t, "previous response not found", upstreamMessage)
+ })
+
+ t.Run("auth_failed_uses_dial_status", func(t *testing.T) {
+ statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse(
+ wrapOpenAIWSFallback("auth_failed", &openAIWSDialError{
+ StatusCode: http.StatusForbidden,
+ Err: errors.New("forbidden"),
+ }),
+ )
+ require.True(t, ok)
+ require.Equal(t, http.StatusForbidden, statusCode)
+ require.Equal(t, "upstream_error", errType)
+ require.Equal(t, "forbidden", clientMessage)
+ require.Equal(t, "forbidden", upstreamMessage)
+ })
+
+ t.Run("non_fallback_error_not_resolved", func(t *testing.T) {
+ _, _, _, _, ok := resolveOpenAIWSFallbackErrorResponse(errors.New("plain error"))
+ require.False(t, ok)
+ })
+}
+
+func TestOpenAIWSFallbackCooling(t *testing.T) {
+ svc := &OpenAIGatewayService{cfg: &config.Config{}}
+ svc.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
+
+ require.False(t, svc.isOpenAIWSFallbackCooling(1))
+ svc.markOpenAIWSFallbackCooling(1, "upgrade_required")
+ require.True(t, svc.isOpenAIWSFallbackCooling(1))
+
+ svc.clearOpenAIWSFallbackCooling(1)
+ require.False(t, svc.isOpenAIWSFallbackCooling(1))
+
+ svc.markOpenAIWSFallbackCooling(2, "x")
+ time.Sleep(1200 * time.Millisecond)
+ require.False(t, svc.isOpenAIWSFallbackCooling(2))
+}
+
+func TestOpenAIWSRetryBackoff(t *testing.T) {
+ svc := &OpenAIGatewayService{cfg: &config.Config{}}
+ svc.cfg.Gateway.OpenAIWS.RetryBackoffInitialMS = 100
+ svc.cfg.Gateway.OpenAIWS.RetryBackoffMaxMS = 400
+ svc.cfg.Gateway.OpenAIWS.RetryJitterRatio = 0
+
+ require.Equal(t, time.Duration(100)*time.Millisecond, svc.openAIWSRetryBackoff(1))
+ require.Equal(t, time.Duration(200)*time.Millisecond, svc.openAIWSRetryBackoff(2))
+ require.Equal(t, time.Duration(400)*time.Millisecond, svc.openAIWSRetryBackoff(3))
+ require.Equal(t, time.Duration(400)*time.Millisecond, svc.openAIWSRetryBackoff(4))
+}
+
+func TestOpenAIWSRetryTotalBudget(t *testing.T) {
+ svc := &OpenAIGatewayService{cfg: &config.Config{}}
+ svc.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS = 1200
+ require.Equal(t, 1200*time.Millisecond, svc.openAIWSRetryTotalBudget())
+
+ svc.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS = 0
+ require.Equal(t, time.Duration(0), svc.openAIWSRetryTotalBudget())
+}
+
+func TestClassifyOpenAIWSReadFallbackReason(t *testing.T) {
+ require.Equal(t, "policy_violation", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusPolicyViolation}))
+ require.Equal(t, "message_too_big", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusMessageTooBig}))
+ require.Equal(t, "read_event", classifyOpenAIWSReadFallbackReason(errors.New("io")))
+}
+
+func TestOpenAIWSStoreDisabledConnMode(t *testing.T) {
+ svc := &OpenAIGatewayService{cfg: &config.Config{}}
+ svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true
+ require.Equal(t, openAIWSStoreDisabledConnModeStrict, svc.openAIWSStoreDisabledConnMode())
+
+ svc.cfg.Gateway.OpenAIWS.StoreDisabledConnMode = "adaptive"
+ require.Equal(t, openAIWSStoreDisabledConnModeAdaptive, svc.openAIWSStoreDisabledConnMode())
+
+ svc.cfg.Gateway.OpenAIWS.StoreDisabledConnMode = ""
+ svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = false
+ require.Equal(t, openAIWSStoreDisabledConnModeOff, svc.openAIWSStoreDisabledConnMode())
+}
+
+func TestShouldForceNewConnOnStoreDisabled(t *testing.T) {
+ require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeStrict, ""))
+ require.False(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeOff, "policy_violation"))
+
+ require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "policy_violation"))
+ require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "prewarm_message_too_big"))
+ require.False(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "read_event"))
+}
+
+func TestOpenAIWSRetryMetricsSnapshot(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ svc.recordOpenAIWSRetryAttempt(150 * time.Millisecond)
+ svc.recordOpenAIWSRetryAttempt(0)
+ svc.recordOpenAIWSRetryExhausted()
+ svc.recordOpenAIWSNonRetryableFastFallback()
+
+ snapshot := svc.SnapshotOpenAIWSRetryMetrics()
+ require.Equal(t, int64(2), snapshot.RetryAttemptsTotal)
+ require.Equal(t, int64(150), snapshot.RetryBackoffMsTotal)
+ require.Equal(t, int64(1), snapshot.RetryExhaustedTotal)
+ require.Equal(t, int64(1), snapshot.NonRetryableFastFallbackTotal)
+}
+
+func TestShouldLogOpenAIWSPayloadSchema(t *testing.T) {
+ svc := &OpenAIGatewayService{cfg: &config.Config{}}
+
+ svc.cfg.Gateway.OpenAIWS.PayloadLogSampleRate = 0
+ require.True(t, svc.shouldLogOpenAIWSPayloadSchema(1), "首次尝试应始终记录 payload_schema")
+ require.False(t, svc.shouldLogOpenAIWSPayloadSchema(2))
+
+ svc.cfg.Gateway.OpenAIWS.PayloadLogSampleRate = 1
+ require.True(t, svc.shouldLogOpenAIWSPayloadSchema(2))
+}
diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go
new file mode 100644
index 00000000..74ba472f
--- /dev/null
+++ b/backend/internal/service/openai_ws_forwarder.go
@@ -0,0 +1,3955 @@
+package service
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "math/rand"
+ "net"
+ "net/http"
+ "net/url"
+ "sort"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
+ "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
+ coderws "github.com/coder/websocket"
+ "github.com/gin-gonic/gin"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+ "go.uber.org/zap"
+)
+
+const (
+ openAIWSBetaV1Value = "responses_websockets=2026-02-04"
+ openAIWSBetaV2Value = "responses_websockets=2026-02-06"
+
+ openAIWSTurnStateHeader = "x-codex-turn-state"
+ openAIWSTurnMetadataHeader = "x-codex-turn-metadata"
+
+ openAIWSLogValueMaxLen = 160
+ openAIWSHeaderValueMaxLen = 120
+ openAIWSIDValueMaxLen = 64
+ openAIWSEventLogHeadLimit = 20
+ openAIWSEventLogEveryN = 50
+ openAIWSBufferLogHeadLimit = 8
+ openAIWSBufferLogEveryN = 20
+ openAIWSPrewarmEventLogHead = 10
+ openAIWSPayloadKeySizeTopN = 6
+
+ openAIWSPayloadSizeEstimateDepth = 3
+ openAIWSPayloadSizeEstimateMaxBytes = 64 * 1024
+ openAIWSPayloadSizeEstimateMaxItems = 16
+
+ openAIWSEventFlushBatchSizeDefault = 4
+ openAIWSEventFlushIntervalDefault = 25 * time.Millisecond
+ openAIWSPayloadLogSampleDefault = 0.2
+
+ openAIWSStoreDisabledConnModeStrict = "strict"
+ openAIWSStoreDisabledConnModeAdaptive = "adaptive"
+ openAIWSStoreDisabledConnModeOff = "off"
+
+ openAIWSIngressStagePreviousResponseNotFound = "previous_response_not_found"
+ openAIWSMaxPrevResponseIDDeletePasses = 8
+)
+
+var openAIWSLogValueReplacer = strings.NewReplacer(
+ "error", "err",
+ "fallback", "fb",
+ "warning", "warnx",
+ "failed", "fail",
+)
+
+var openAIWSIngressPreflightPingIdle = 20 * time.Second
+
+// openAIWSFallbackError 表示可安全回退到 HTTP 的 WS 错误(尚未写下游)。
+type openAIWSFallbackError struct {
+ Reason string
+ Err error
+}
+
+func (e *openAIWSFallbackError) Error() string {
+ if e == nil {
+ return ""
+ }
+ if e.Err == nil {
+ return fmt.Sprintf("openai ws fallback: %s", strings.TrimSpace(e.Reason))
+ }
+ return fmt.Sprintf("openai ws fallback: %s: %v", strings.TrimSpace(e.Reason), e.Err)
+}
+
+func (e *openAIWSFallbackError) Unwrap() error {
+ if e == nil {
+ return nil
+ }
+ return e.Err
+}
+
+func wrapOpenAIWSFallback(reason string, err error) error {
+ return &openAIWSFallbackError{Reason: strings.TrimSpace(reason), Err: err}
+}
+
+// OpenAIWSClientCloseError 表示应以指定 WebSocket close code 主动关闭客户端连接的错误。
+type OpenAIWSClientCloseError struct {
+ statusCode coderws.StatusCode
+ reason string
+ err error
+}
+
+type openAIWSIngressTurnError struct {
+ stage string
+ cause error
+ wroteDownstream bool
+}
+
+func (e *openAIWSIngressTurnError) Error() string {
+ if e == nil {
+ return ""
+ }
+ if e.cause == nil {
+ return strings.TrimSpace(e.stage)
+ }
+ return e.cause.Error()
+}
+
+func (e *openAIWSIngressTurnError) Unwrap() error {
+ if e == nil {
+ return nil
+ }
+ return e.cause
+}
+
+func wrapOpenAIWSIngressTurnError(stage string, cause error, wroteDownstream bool) error {
+ if cause == nil {
+ return nil
+ }
+ return &openAIWSIngressTurnError{
+ stage: strings.TrimSpace(stage),
+ cause: cause,
+ wroteDownstream: wroteDownstream,
+ }
+}
+
+func isOpenAIWSIngressTurnRetryable(err error) bool {
+ var turnErr *openAIWSIngressTurnError
+ if !errors.As(err, &turnErr) || turnErr == nil {
+ return false
+ }
+ if errors.Is(turnErr.cause, context.Canceled) || errors.Is(turnErr.cause, context.DeadlineExceeded) {
+ return false
+ }
+ if turnErr.wroteDownstream {
+ return false
+ }
+ switch turnErr.stage {
+ case "write_upstream", "read_upstream":
+ return true
+ default:
+ return false
+ }
+}
+
+func openAIWSIngressTurnRetryReason(err error) string {
+ var turnErr *openAIWSIngressTurnError
+ if !errors.As(err, &turnErr) || turnErr == nil {
+ return "unknown"
+ }
+ if turnErr.stage == "" {
+ return "unknown"
+ }
+ return turnErr.stage
+}
+
+func isOpenAIWSIngressPreviousResponseNotFound(err error) bool {
+ var turnErr *openAIWSIngressTurnError
+ if !errors.As(err, &turnErr) || turnErr == nil {
+ return false
+ }
+ if strings.TrimSpace(turnErr.stage) != openAIWSIngressStagePreviousResponseNotFound {
+ return false
+ }
+ return !turnErr.wroteDownstream
+}
+
+// NewOpenAIWSClientCloseError 创建一个客户端 WS 关闭错误。
+func NewOpenAIWSClientCloseError(statusCode coderws.StatusCode, reason string, err error) error {
+ return &OpenAIWSClientCloseError{
+ statusCode: statusCode,
+ reason: strings.TrimSpace(reason),
+ err: err,
+ }
+}
+
+func (e *OpenAIWSClientCloseError) Error() string {
+ if e == nil {
+ return ""
+ }
+ if e.err == nil {
+ return fmt.Sprintf("openai ws client close: %d %s", int(e.statusCode), strings.TrimSpace(e.reason))
+ }
+ return fmt.Sprintf("openai ws client close: %d %s: %v", int(e.statusCode), strings.TrimSpace(e.reason), e.err)
+}
+
+func (e *OpenAIWSClientCloseError) Unwrap() error {
+ if e == nil {
+ return nil
+ }
+ return e.err
+}
+
+func (e *OpenAIWSClientCloseError) StatusCode() coderws.StatusCode {
+ if e == nil {
+ return coderws.StatusInternalError
+ }
+ return e.statusCode
+}
+
+func (e *OpenAIWSClientCloseError) Reason() string {
+ if e == nil {
+ return ""
+ }
+ return strings.TrimSpace(e.reason)
+}
+
+// OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。
+type OpenAIWSIngressHooks struct {
+ BeforeTurn func(turn int) error
+ AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
+}
+
+func normalizeOpenAIWSLogValue(value string) string {
+ trimmed := strings.TrimSpace(value)
+ if trimmed == "" {
+ return "-"
+ }
+ return openAIWSLogValueReplacer.Replace(trimmed)
+}
+
+func truncateOpenAIWSLogValue(value string, maxLen int) string {
+ normalized := normalizeOpenAIWSLogValue(value)
+ if normalized == "-" || maxLen <= 0 {
+ return normalized
+ }
+ if len(normalized) <= maxLen {
+ return normalized
+ }
+ return normalized[:maxLen] + "..."
+}
+
+func openAIWSHeaderValueForLog(headers http.Header, key string) string {
+ if headers == nil {
+ return "-"
+ }
+ return truncateOpenAIWSLogValue(headers.Get(key), openAIWSHeaderValueMaxLen)
+}
+
+func hasOpenAIWSHeader(headers http.Header, key string) bool {
+ if headers == nil {
+ return false
+ }
+ return strings.TrimSpace(headers.Get(key)) != ""
+}
+
+type openAIWSSessionHeaderResolution struct {
+ SessionID string
+ ConversationID string
+ SessionSource string
+ ConversationSource string
+}
+
+func resolveOpenAIWSSessionHeaders(c *gin.Context, promptCacheKey string) openAIWSSessionHeaderResolution {
+ resolution := openAIWSSessionHeaderResolution{
+ SessionSource: "none",
+ ConversationSource: "none",
+ }
+ if c != nil && c.Request != nil {
+ if sessionID := strings.TrimSpace(c.Request.Header.Get("session_id")); sessionID != "" {
+ resolution.SessionID = sessionID
+ resolution.SessionSource = "header_session_id"
+ }
+ if conversationID := strings.TrimSpace(c.Request.Header.Get("conversation_id")); conversationID != "" {
+ resolution.ConversationID = conversationID
+ resolution.ConversationSource = "header_conversation_id"
+ if resolution.SessionID == "" {
+ resolution.SessionID = conversationID
+ resolution.SessionSource = "header_conversation_id"
+ }
+ }
+ }
+
+ cacheKey := strings.TrimSpace(promptCacheKey)
+ if cacheKey != "" {
+ if resolution.SessionID == "" {
+ resolution.SessionID = cacheKey
+ resolution.SessionSource = "prompt_cache_key"
+ }
+ }
+ return resolution
+}
+
+func shouldLogOpenAIWSEvent(idx int, eventType string) bool {
+ if idx <= openAIWSEventLogHeadLimit {
+ return true
+ }
+ if openAIWSEventLogEveryN > 0 && idx%openAIWSEventLogEveryN == 0 {
+ return true
+ }
+ if eventType == "error" || isOpenAIWSTerminalEvent(eventType) {
+ return true
+ }
+ return false
+}
+
+func shouldLogOpenAIWSBufferedEvent(idx int) bool {
+ if idx <= openAIWSBufferLogHeadLimit {
+ return true
+ }
+ if openAIWSBufferLogEveryN > 0 && idx%openAIWSBufferLogEveryN == 0 {
+ return true
+ }
+ return false
+}
+
+func openAIWSEventMayContainModel(eventType string) bool {
+ switch eventType {
+ case "response.created",
+ "response.in_progress",
+ "response.completed",
+ "response.done",
+ "response.failed",
+ "response.incomplete",
+ "response.cancelled",
+ "response.canceled":
+ return true
+ default:
+ trimmed := strings.TrimSpace(eventType)
+ if trimmed == eventType {
+ return false
+ }
+ switch trimmed {
+ case "response.created",
+ "response.in_progress",
+ "response.completed",
+ "response.done",
+ "response.failed",
+ "response.incomplete",
+ "response.cancelled",
+ "response.canceled":
+ return true
+ default:
+ return false
+ }
+ }
+}
+
+func openAIWSEventMayContainToolCalls(eventType string) bool {
+ eventType = strings.TrimSpace(eventType)
+ if eventType == "" {
+ return false
+ }
+ if strings.Contains(eventType, "function_call") || strings.Contains(eventType, "tool_call") {
+ return true
+ }
+ switch eventType {
+ case "response.output_item.added", "response.output_item.done", "response.completed", "response.done":
+ return true
+ default:
+ return false
+ }
+}
+
+func openAIWSEventShouldParseUsage(eventType string) bool {
+ return eventType == "response.completed" || strings.TrimSpace(eventType) == "response.completed"
+}
+
+func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) {
+ if len(message) == 0 {
+ return "", "", gjson.Result{}
+ }
+ values := gjson.GetManyBytes(message, "type", "response.id", "id", "response")
+ eventType = strings.TrimSpace(values[0].String())
+ if id := strings.TrimSpace(values[1].String()); id != "" {
+ responseID = id
+ } else {
+ responseID = strings.TrimSpace(values[2].String())
+ }
+ return eventType, responseID, values[3]
+}
+
+func openAIWSMessageLikelyContainsToolCalls(message []byte) bool {
+ if len(message) == 0 {
+ return false
+ }
+ return bytes.Contains(message, []byte(`"tool_calls"`)) ||
+ bytes.Contains(message, []byte(`"tool_call"`)) ||
+ bytes.Contains(message, []byte(`"function_call"`))
+}
+
+func parseOpenAIWSResponseUsageFromCompletedEvent(message []byte, usage *OpenAIUsage) {
+ if usage == nil || len(message) == 0 {
+ return
+ }
+ values := gjson.GetManyBytes(
+ message,
+ "response.usage.input_tokens",
+ "response.usage.output_tokens",
+ "response.usage.input_tokens_details.cached_tokens",
+ )
+ usage.InputTokens = int(values[0].Int())
+ usage.OutputTokens = int(values[1].Int())
+ usage.CacheReadInputTokens = int(values[2].Int())
+}
+
+func parseOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) {
+ if len(message) == 0 {
+ return "", "", ""
+ }
+ values := gjson.GetManyBytes(message, "error.code", "error.type", "error.message")
+ return strings.TrimSpace(values[0].String()), strings.TrimSpace(values[1].String()), strings.TrimSpace(values[2].String())
+}
+
+func summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMessageRaw string) (code string, errType string, errMessage string) {
+ code = truncateOpenAIWSLogValue(codeRaw, openAIWSLogValueMaxLen)
+ errType = truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen)
+ errMessage = truncateOpenAIWSLogValue(errMessageRaw, openAIWSLogValueMaxLen)
+ return code, errType, errMessage
+}
+
+func summarizeOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) {
+ if len(message) == 0 {
+ return "-", "-", "-"
+ }
+ return summarizeOpenAIWSErrorEventFieldsFromRaw(parseOpenAIWSErrorEventFields(message))
+}
+
+func summarizeOpenAIWSPayloadKeySizes(payload map[string]any, topN int) string {
+ if len(payload) == 0 {
+ return "-"
+ }
+ type keySize struct {
+ Key string
+ Size int
+ }
+ sizes := make([]keySize, 0, len(payload))
+ for key, value := range payload {
+ size := estimateOpenAIWSPayloadValueSize(value, openAIWSPayloadSizeEstimateDepth)
+ sizes = append(sizes, keySize{Key: key, Size: size})
+ }
+ sort.Slice(sizes, func(i, j int) bool {
+ if sizes[i].Size == sizes[j].Size {
+ return sizes[i].Key < sizes[j].Key
+ }
+ return sizes[i].Size > sizes[j].Size
+ })
+
+ if topN <= 0 || topN > len(sizes) {
+ topN = len(sizes)
+ }
+ parts := make([]string, 0, topN)
+ for idx := 0; idx < topN; idx++ {
+ item := sizes[idx]
+ parts = append(parts, fmt.Sprintf("%s:%d", item.Key, item.Size))
+ }
+ return strings.Join(parts, ",")
+}
+
+func estimateOpenAIWSPayloadValueSize(value any, depth int) int {
+ if depth <= 0 {
+ return -1
+ }
+ switch v := value.(type) {
+ case nil:
+ return 0
+ case string:
+ return len(v)
+ case []byte:
+ return len(v)
+ case bool:
+ return 1
+ case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
+ return 8
+ case float32, float64:
+ return 8
+ case map[string]any:
+ if len(v) == 0 {
+ return 2
+ }
+ total := 2
+ count := 0
+ for key, item := range v {
+ count++
+ if count > openAIWSPayloadSizeEstimateMaxItems {
+ return -1
+ }
+ itemSize := estimateOpenAIWSPayloadValueSize(item, depth-1)
+ if itemSize < 0 {
+ return -1
+ }
+ total += len(key) + itemSize + 3
+ if total > openAIWSPayloadSizeEstimateMaxBytes {
+ return -1
+ }
+ }
+ return total
+ case []any:
+ if len(v) == 0 {
+ return 2
+ }
+ total := 2
+ limit := len(v)
+ if limit > openAIWSPayloadSizeEstimateMaxItems {
+ return -1
+ }
+ for i := 0; i < limit; i++ {
+ itemSize := estimateOpenAIWSPayloadValueSize(v[i], depth-1)
+ if itemSize < 0 {
+ return -1
+ }
+ total += itemSize + 1
+ if total > openAIWSPayloadSizeEstimateMaxBytes {
+ return -1
+ }
+ }
+ return total
+ default:
+ raw, err := json.Marshal(v)
+ if err != nil {
+ return -1
+ }
+ if len(raw) > openAIWSPayloadSizeEstimateMaxBytes {
+ return -1
+ }
+ return len(raw)
+ }
+}
+
+func openAIWSPayloadString(payload map[string]any, key string) string {
+ if len(payload) == 0 {
+ return ""
+ }
+ raw, ok := payload[key]
+ if !ok {
+ return ""
+ }
+ switch v := raw.(type) {
+ case nil:
+ return ""
+ case string:
+ return strings.TrimSpace(v)
+ case []byte:
+ return strings.TrimSpace(string(v))
+ default:
+ return ""
+ }
+}
+
+func openAIWSPayloadStringFromRaw(payload []byte, key string) string {
+ if len(payload) == 0 || strings.TrimSpace(key) == "" {
+ return ""
+ }
+ return strings.TrimSpace(gjson.GetBytes(payload, key).String())
+}
+
+func openAIWSPayloadBoolFromRaw(payload []byte, key string, defaultValue bool) bool {
+ if len(payload) == 0 || strings.TrimSpace(key) == "" {
+ return defaultValue
+ }
+ value := gjson.GetBytes(payload, key)
+ if !value.Exists() {
+ return defaultValue
+ }
+ if value.Type != gjson.True && value.Type != gjson.False {
+ return defaultValue
+ }
+ return value.Bool()
+}
+
+func openAIWSSessionHashesFromID(sessionID string) (string, string) {
+ return deriveOpenAISessionHashes(sessionID)
+}
+
+func extractOpenAIWSImageURL(value any) string {
+ switch v := value.(type) {
+ case string:
+ return strings.TrimSpace(v)
+ case map[string]any:
+ if raw, ok := v["url"].(string); ok {
+ return strings.TrimSpace(raw)
+ }
+ }
+ return ""
+}
+
+func summarizeOpenAIWSInput(input any) string {
+ items, ok := input.([]any)
+ if !ok || len(items) == 0 {
+ return "-"
+ }
+
+ itemCount := len(items)
+ textChars := 0
+ imageDataURLs := 0
+ imageDataURLChars := 0
+ imageRemoteURLs := 0
+
+ handleContentItem := func(contentItem map[string]any) {
+ contentType, _ := contentItem["type"].(string)
+ switch strings.TrimSpace(contentType) {
+ case "input_text", "output_text", "text":
+ if text, ok := contentItem["text"].(string); ok {
+ textChars += len(text)
+ }
+ case "input_image":
+ imageURL := extractOpenAIWSImageURL(contentItem["image_url"])
+ if imageURL == "" {
+ return
+ }
+ if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") {
+ imageDataURLs++
+ imageDataURLChars += len(imageURL)
+ return
+ }
+ imageRemoteURLs++
+ }
+ }
+
+ handleInputItem := func(inputItem map[string]any) {
+ if content, ok := inputItem["content"].([]any); ok {
+ for _, rawContent := range content {
+ contentItem, ok := rawContent.(map[string]any)
+ if !ok {
+ continue
+ }
+ handleContentItem(contentItem)
+ }
+ return
+ }
+
+ itemType, _ := inputItem["type"].(string)
+ switch strings.TrimSpace(itemType) {
+ case "input_text", "output_text", "text":
+ if text, ok := inputItem["text"].(string); ok {
+ textChars += len(text)
+ }
+ case "input_image":
+ imageURL := extractOpenAIWSImageURL(inputItem["image_url"])
+ if imageURL == "" {
+ return
+ }
+ if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") {
+ imageDataURLs++
+ imageDataURLChars += len(imageURL)
+ return
+ }
+ imageRemoteURLs++
+ }
+ }
+
+ for _, rawItem := range items {
+ inputItem, ok := rawItem.(map[string]any)
+ if !ok {
+ continue
+ }
+ handleInputItem(inputItem)
+ }
+
+ return fmt.Sprintf(
+ "items=%d,text_chars=%d,image_data_urls=%d,image_data_url_chars=%d,image_remote_urls=%d",
+ itemCount,
+ textChars,
+ imageDataURLs,
+ imageDataURLChars,
+ imageRemoteURLs,
+ )
+}
+
+func dropOpenAIWSPayloadKey(payload map[string]any, key string, removed *[]string) {
+ if len(payload) == 0 || strings.TrimSpace(key) == "" {
+ return
+ }
+ if _, exists := payload[key]; !exists {
+ return
+ }
+ delete(payload, key)
+ *removed = append(*removed, key)
+}
+
+// applyOpenAIWSRetryPayloadStrategy 在 WS 连续失败时仅移除无语义字段,
+// 避免重试成功却改变原始请求语义。
+// 注意:prompt_cache_key 不应在重试中移除;它常用于会话稳定标识(session_id 兜底)。
+func applyOpenAIWSRetryPayloadStrategy(payload map[string]any, attempt int) (strategy string, removedKeys []string) {
+ if len(payload) == 0 {
+ return "empty", nil
+ }
+ if attempt <= 1 {
+ return "full", nil
+ }
+
+ removed := make([]string, 0, 2)
+ if attempt >= 2 {
+ dropOpenAIWSPayloadKey(payload, "include", &removed)
+ }
+
+ if len(removed) == 0 {
+ return "full", nil
+ }
+ sort.Strings(removed)
+ return "trim_optional_fields", removed
+}
+
+func logOpenAIWSModeInfo(format string, args ...any) {
+ logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode][openai_ws_mode=true] "+format, args...)
+}
+
+func isOpenAIWSModeDebugEnabled() bool {
+ return logger.L().Core().Enabled(zap.DebugLevel)
+}
+
+func logOpenAIWSModeDebug(format string, args ...any) {
+ if !isOpenAIWSModeDebugEnabled() {
+ return
+ }
+ logger.LegacyPrintf("service.openai_gateway", "[debug] [OpenAI WS Mode][openai_ws_mode=true] "+format, args...)
+}
+
+func logOpenAIWSBindResponseAccountWarn(groupID, accountID int64, responseID string, err error) {
+ if err == nil {
+ return
+ }
+ logger.L().Warn(
+ "openai.ws_bind_response_account_failed",
+ zap.Int64("group_id", groupID),
+ zap.Int64("account_id", accountID),
+ zap.String("response_id", truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen)),
+ zap.Error(err),
+ )
+}
+
+func summarizeOpenAIWSReadCloseError(err error) (status string, reason string) {
+ if err == nil {
+ return "-", "-"
+ }
+ statusCode := coderws.CloseStatus(err)
+ if statusCode == -1 {
+ return "-", "-"
+ }
+ closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String())
+ closeReason := "-"
+ var closeErr coderws.CloseError
+ if errors.As(err, &closeErr) {
+ reasonText := strings.TrimSpace(closeErr.Reason)
+ if reasonText != "" {
+ closeReason = normalizeOpenAIWSLogValue(reasonText)
+ }
+ }
+ return normalizeOpenAIWSLogValue(closeStatus), closeReason
+}
+
+func unwrapOpenAIWSDialBaseError(err error) error {
+ if err == nil {
+ return nil
+ }
+ var dialErr *openAIWSDialError
+ if errors.As(err, &dialErr) && dialErr != nil && dialErr.Err != nil {
+ return dialErr.Err
+ }
+ return err
+}
+
+func openAIWSDialRespHeaderForLog(err error, key string) string {
+ var dialErr *openAIWSDialError
+ if !errors.As(err, &dialErr) || dialErr == nil || dialErr.ResponseHeaders == nil {
+ return "-"
+ }
+ return truncateOpenAIWSLogValue(dialErr.ResponseHeaders.Get(key), openAIWSHeaderValueMaxLen)
+}
+
+func classifyOpenAIWSDialError(err error) string {
+ if err == nil {
+ return "-"
+ }
+ baseErr := unwrapOpenAIWSDialBaseError(err)
+ if baseErr == nil {
+ return "-"
+ }
+ if errors.Is(baseErr, context.DeadlineExceeded) {
+ return "ctx_deadline_exceeded"
+ }
+ if errors.Is(baseErr, context.Canceled) {
+ return "ctx_canceled"
+ }
+ var netErr net.Error
+ if errors.As(baseErr, &netErr) && netErr.Timeout() {
+ return "net_timeout"
+ }
+ if status := coderws.CloseStatus(baseErr); status != -1 {
+ return normalizeOpenAIWSLogValue(fmt.Sprintf("ws_close_%d", int(status)))
+ }
+ message := strings.ToLower(strings.TrimSpace(baseErr.Error()))
+ switch {
+ case strings.Contains(message, "handshake not finished"):
+ return "handshake_not_finished"
+ case strings.Contains(message, "bad handshake"):
+ return "bad_handshake"
+ case strings.Contains(message, "connection refused"):
+ return "connection_refused"
+ case strings.Contains(message, "no such host"):
+ return "dns_not_found"
+ case strings.Contains(message, "tls"):
+ return "tls_error"
+ case strings.Contains(message, "i/o timeout"):
+ return "io_timeout"
+ case strings.Contains(message, "context deadline exceeded"):
+ return "ctx_deadline_exceeded"
+ default:
+ return "dial_error"
+ }
+}
+
+func summarizeOpenAIWSDialError(err error) (
+ statusCode int,
+ dialClass string,
+ closeStatus string,
+ closeReason string,
+ respServer string,
+ respVia string,
+ respCFRay string,
+ respRequestID string,
+) {
+ dialClass = "-"
+ closeStatus = "-"
+ closeReason = "-"
+ respServer = "-"
+ respVia = "-"
+ respCFRay = "-"
+ respRequestID = "-"
+ if err == nil {
+ return
+ }
+ var dialErr *openAIWSDialError
+ if errors.As(err, &dialErr) && dialErr != nil {
+ statusCode = dialErr.StatusCode
+ respServer = openAIWSDialRespHeaderForLog(err, "server")
+ respVia = openAIWSDialRespHeaderForLog(err, "via")
+ respCFRay = openAIWSDialRespHeaderForLog(err, "cf-ray")
+ respRequestID = openAIWSDialRespHeaderForLog(err, "x-request-id")
+ }
+ dialClass = normalizeOpenAIWSLogValue(classifyOpenAIWSDialError(err))
+ closeStatus, closeReason = summarizeOpenAIWSReadCloseError(unwrapOpenAIWSDialBaseError(err))
+ return
+}
+
+func isOpenAIWSClientDisconnectError(err error) bool {
+ if err == nil {
+ return false
+ }
+ if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
+ return true
+ }
+ switch coderws.CloseStatus(err) {
+ case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure:
+ return true
+ }
+ message := strings.ToLower(strings.TrimSpace(err.Error()))
+ if message == "" {
+ return false
+ }
+ return strings.Contains(message, "failed to read frame header: eof") ||
+ strings.Contains(message, "unexpected eof") ||
+ strings.Contains(message, "use of closed network connection") ||
+ strings.Contains(message, "connection reset by peer") ||
+ strings.Contains(message, "broken pipe")
+}
+
+func classifyOpenAIWSReadFallbackReason(err error) string {
+ if err == nil {
+ return "read_event"
+ }
+ switch coderws.CloseStatus(err) {
+ case coderws.StatusPolicyViolation:
+ return "policy_violation"
+ case coderws.StatusMessageTooBig:
+ return "message_too_big"
+ default:
+ return "read_event"
+ }
+}
+
+func sortedKeys(m map[string]any) []string {
+ if len(m) == 0 {
+ return nil
+ }
+ keys := make([]string, 0, len(m))
+ for k := range m {
+ keys = append(keys, k)
+ }
+ sort.Strings(keys)
+ return keys
+}
+
+func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool {
+ if s == nil {
+ return nil
+ }
+ s.openaiWSPoolOnce.Do(func() {
+ if s.openaiWSPool == nil {
+ s.openaiWSPool = newOpenAIWSConnPool(s.cfg)
+ }
+ })
+ return s.openaiWSPool
+}
+
+func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot {
+ pool := s.getOpenAIWSConnPool()
+ if pool == nil {
+ return OpenAIWSPoolMetricsSnapshot{}
+ }
+ return pool.SnapshotMetrics()
+}
+
+type OpenAIWSPerformanceMetricsSnapshot struct {
+ Pool OpenAIWSPoolMetricsSnapshot `json:"pool"`
+ Retry OpenAIWSRetryMetricsSnapshot `json:"retry"`
+ Transport OpenAIWSTransportMetricsSnapshot `json:"transport"`
+}
+
+func (s *OpenAIGatewayService) SnapshotOpenAIWSPerformanceMetrics() OpenAIWSPerformanceMetricsSnapshot {
+ pool := s.getOpenAIWSConnPool()
+ snapshot := OpenAIWSPerformanceMetricsSnapshot{
+ Retry: s.SnapshotOpenAIWSRetryMetrics(),
+ }
+ if pool == nil {
+ return snapshot
+ }
+ snapshot.Pool = pool.SnapshotMetrics()
+ snapshot.Transport = pool.SnapshotTransportMetrics()
+ return snapshot
+}
+
+func (s *OpenAIGatewayService) getOpenAIWSStateStore() OpenAIWSStateStore {
+ if s == nil {
+ return nil
+ }
+ s.openaiWSStateStoreOnce.Do(func() {
+ if s.openaiWSStateStore == nil {
+ s.openaiWSStateStore = NewOpenAIWSStateStore(s.cache)
+ }
+ })
+ return s.openaiWSStateStore
+}
+
+func (s *OpenAIGatewayService) openAIWSResponseStickyTTL() time.Duration {
+ if s != nil && s.cfg != nil {
+ seconds := s.cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds
+ if seconds > 0 {
+ return time.Duration(seconds) * time.Second
+ }
+ }
+ return time.Hour
+}
+
+func (s *OpenAIGatewayService) openAIWSIngressPreviousResponseRecoveryEnabled() bool {
+ if s != nil && s.cfg != nil {
+ return s.cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled
+ }
+ return true
+}
+
+func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration {
+ if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds > 0 {
+ return time.Duration(s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds) * time.Second
+ }
+ return 15 * time.Minute
+}
+
+func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration {
+ if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 {
+ return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second
+ }
+ return 2 * time.Minute
+}
+
+func (s *OpenAIGatewayService) openAIWSEventFlushBatchSize() int {
+ if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushBatchSize > 0 {
+ return s.cfg.Gateway.OpenAIWS.EventFlushBatchSize
+ }
+ return openAIWSEventFlushBatchSizeDefault
+}
+
+func (s *OpenAIGatewayService) openAIWSEventFlushInterval() time.Duration {
+ if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS >= 0 {
+ if s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS == 0 {
+ return 0
+ }
+ return time.Duration(s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS) * time.Millisecond
+ }
+ return openAIWSEventFlushIntervalDefault
+}
+
+func (s *OpenAIGatewayService) openAIWSPayloadLogSampleRate() float64 {
+ if s != nil && s.cfg != nil {
+ rate := s.cfg.Gateway.OpenAIWS.PayloadLogSampleRate
+ if rate < 0 {
+ return 0
+ }
+ if rate > 1 {
+ return 1
+ }
+ return rate
+ }
+ return openAIWSPayloadLogSampleDefault
+}
+
+func (s *OpenAIGatewayService) shouldLogOpenAIWSPayloadSchema(attempt int) bool {
+ // 首次尝试保留一条完整 payload_schema 便于排障。
+ if attempt <= 1 {
+ return true
+ }
+ rate := s.openAIWSPayloadLogSampleRate()
+ if rate <= 0 {
+ return false
+ }
+ if rate >= 1 {
+ return true
+ }
+ return rand.Float64() < rate
+}
+
+func (s *OpenAIGatewayService) shouldEmitOpenAIWSPayloadSchema(attempt int) bool {
+ if !s.shouldLogOpenAIWSPayloadSchema(attempt) {
+ return false
+ }
+ return logger.L().Core().Enabled(zap.DebugLevel)
+}
+
+func (s *OpenAIGatewayService) openAIWSDialTimeout() time.Duration {
+ if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds > 0 {
+ return time.Duration(s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds) * time.Second
+ }
+ return 10 * time.Second
+}
+
+func (s *OpenAIGatewayService) openAIWSAcquireTimeout() time.Duration {
+ // Acquire 覆盖“连接复用命中/排队/新建连接”三个阶段。
+ // 这里不再叠加 write_timeout,避免高并发排队时把 TTFT 长尾拉到分钟级。
+ dial := s.openAIWSDialTimeout()
+ if dial <= 0 {
+ dial = 10 * time.Second
+ }
+ return dial + 2*time.Second
+}
+
+func (s *OpenAIGatewayService) buildOpenAIResponsesWSURL(account *Account) (string, error) {
+ if account == nil {
+ return "", errors.New("account is nil")
+ }
+ var targetURL string
+ switch account.Type {
+ case AccountTypeOAuth:
+ targetURL = chatgptCodexURL
+ case AccountTypeAPIKey:
+ baseURL := account.GetOpenAIBaseURL()
+ if baseURL == "" {
+ targetURL = openaiPlatformAPIURL
+ } else {
+ validatedURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return "", err
+ }
+ targetURL = buildOpenAIResponsesURL(validatedURL)
+ }
+ default:
+ targetURL = openaiPlatformAPIURL
+ }
+
+ parsed, err := url.Parse(strings.TrimSpace(targetURL))
+ if err != nil {
+ return "", fmt.Errorf("invalid target url: %w", err)
+ }
+ switch strings.ToLower(parsed.Scheme) {
+ case "https":
+ parsed.Scheme = "wss"
+ case "http":
+ parsed.Scheme = "ws"
+ case "wss", "ws":
+ // 保持不变
+ default:
+ return "", fmt.Errorf("unsupported scheme for ws: %s", parsed.Scheme)
+ }
+ return parsed.String(), nil
+}
+
+func (s *OpenAIGatewayService) buildOpenAIWSHeaders(
+ c *gin.Context,
+ account *Account,
+ token string,
+ decision OpenAIWSProtocolDecision,
+ isCodexCLI bool,
+ turnState string,
+ turnMetadata string,
+ promptCacheKey string,
+) (http.Header, openAIWSSessionHeaderResolution) {
+ headers := make(http.Header)
+ headers.Set("authorization", "Bearer "+token)
+
+ sessionResolution := resolveOpenAIWSSessionHeaders(c, promptCacheKey)
+ if c != nil && c.Request != nil {
+ if v := strings.TrimSpace(c.Request.Header.Get("accept-language")); v != "" {
+ headers.Set("accept-language", v)
+ }
+ }
+ if sessionResolution.SessionID != "" {
+ headers.Set("session_id", sessionResolution.SessionID)
+ }
+ if sessionResolution.ConversationID != "" {
+ headers.Set("conversation_id", sessionResolution.ConversationID)
+ }
+ if state := strings.TrimSpace(turnState); state != "" {
+ headers.Set(openAIWSTurnStateHeader, state)
+ }
+ if metadata := strings.TrimSpace(turnMetadata); metadata != "" {
+ headers.Set(openAIWSTurnMetadataHeader, metadata)
+ }
+
+ if account != nil && account.Type == AccountTypeOAuth {
+ if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
+ headers.Set("chatgpt-account-id", chatgptAccountID)
+ }
+ if isCodexCLI {
+ headers.Set("originator", "codex_cli_rs")
+ } else {
+ headers.Set("originator", "opencode")
+ }
+ }
+
+ betaValue := openAIWSBetaV2Value
+ if decision.Transport == OpenAIUpstreamTransportResponsesWebsocket {
+ betaValue = openAIWSBetaV1Value
+ }
+ headers.Set("OpenAI-Beta", betaValue)
+
+ customUA := ""
+ if account != nil {
+ customUA = account.GetOpenAIUserAgent()
+ }
+ if strings.TrimSpace(customUA) != "" {
+ headers.Set("user-agent", customUA)
+ } else if c != nil {
+ if ua := strings.TrimSpace(c.GetHeader("User-Agent")); ua != "" {
+ headers.Set("user-agent", ua)
+ }
+ }
+ if s != nil && s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
+ headers.Set("user-agent", codexCLIUserAgent)
+ }
+ if account != nil && account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(headers.Get("user-agent")) {
+ headers.Set("user-agent", codexCLIUserAgent)
+ }
+
+ return headers, sessionResolution
+}
+
+func (s *OpenAIGatewayService) buildOpenAIWSCreatePayload(reqBody map[string]any, account *Account) map[string]any {
+ // OpenAI WS Mode 协议:response.create 字段与 HTTP /responses 基本一致。
+ // 保留 stream 字段(与 Codex CLI 一致),仅移除 background。
+ payload := make(map[string]any, len(reqBody)+1)
+ for k, v := range reqBody {
+ payload[k] = v
+ }
+
+ delete(payload, "background")
+ if _, exists := payload["stream"]; !exists {
+ payload["stream"] = true
+ }
+ payload["type"] = "response.create"
+
+ // OAuth 默认保持 store=false,避免误依赖服务端历史。
+ if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) {
+ payload["store"] = false
+ }
+ return payload
+}
+
+func setOpenAIWSTurnMetadata(payload map[string]any, turnMetadata string) {
+ if len(payload) == 0 {
+ return
+ }
+ metadata := strings.TrimSpace(turnMetadata)
+ if metadata == "" {
+ return
+ }
+
+ switch existing := payload["client_metadata"].(type) {
+ case map[string]any:
+ existing[openAIWSTurnMetadataHeader] = metadata
+ payload["client_metadata"] = existing
+ case map[string]string:
+ next := make(map[string]any, len(existing)+1)
+ for k, v := range existing {
+ next[k] = v
+ }
+ next[openAIWSTurnMetadataHeader] = metadata
+ payload["client_metadata"] = next
+ default:
+ payload["client_metadata"] = map[string]any{
+ openAIWSTurnMetadataHeader: metadata,
+ }
+ }
+}
+
+func (s *OpenAIGatewayService) isOpenAIWSStoreRecoveryAllowed(account *Account) bool {
+ if account != nil && account.IsOpenAIWSAllowStoreRecoveryEnabled() {
+ return true
+ }
+ if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.AllowStoreRecovery {
+ return true
+ }
+ return false
+}
+
+func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequest(reqBody map[string]any, account *Account) bool {
+ if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) {
+ return true
+ }
+ if len(reqBody) == 0 {
+ return false
+ }
+ rawStore, ok := reqBody["store"]
+ if !ok {
+ return false
+ }
+ storeEnabled, ok := rawStore.(bool)
+ if !ok {
+ return false
+ }
+ return !storeEnabled
+}
+
+func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequestRaw(reqBody []byte, account *Account) bool {
+ if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) {
+ return true
+ }
+ if len(reqBody) == 0 {
+ return false
+ }
+ storeValue := gjson.GetBytes(reqBody, "store")
+ if !storeValue.Exists() {
+ return false
+ }
+ if storeValue.Type != gjson.True && storeValue.Type != gjson.False {
+ return false
+ }
+ return !storeValue.Bool()
+}
+
+func (s *OpenAIGatewayService) openAIWSStoreDisabledConnMode() string {
+ if s == nil || s.cfg == nil {
+ return openAIWSStoreDisabledConnModeStrict
+ }
+ mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.OpenAIWS.StoreDisabledConnMode))
+ switch mode {
+ case openAIWSStoreDisabledConnModeStrict, openAIWSStoreDisabledConnModeAdaptive, openAIWSStoreDisabledConnModeOff:
+ return mode
+ case "":
+ // 兼容旧配置:仅配置了布尔开关时按旧语义推导。
+ if s.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn {
+ return openAIWSStoreDisabledConnModeStrict
+ }
+ return openAIWSStoreDisabledConnModeOff
+ default:
+ return openAIWSStoreDisabledConnModeStrict
+ }
+}
+
+func shouldForceNewConnOnStoreDisabled(mode, lastFailureReason string) bool {
+ switch mode {
+ case openAIWSStoreDisabledConnModeOff:
+ return false
+ case openAIWSStoreDisabledConnModeAdaptive:
+ reason := strings.TrimPrefix(strings.TrimSpace(lastFailureReason), "prewarm_")
+ switch reason {
+ case "policy_violation", "message_too_big", "auth_failed", "write_request", "write":
+ return true
+ default:
+ return false
+ }
+ default:
+ return true
+ }
+}
+
+func dropPreviousResponseIDFromRawPayload(payload []byte) ([]byte, bool, error) {
+ return dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, sjson.DeleteBytes)
+}
+
+func dropPreviousResponseIDFromRawPayloadWithDeleteFn(
+ payload []byte,
+ deleteFn func([]byte, string) ([]byte, error),
+) ([]byte, bool, error) {
+ if len(payload) == 0 {
+ return payload, false, nil
+ }
+ if !gjson.GetBytes(payload, "previous_response_id").Exists() {
+ return payload, false, nil
+ }
+ if deleteFn == nil {
+ deleteFn = sjson.DeleteBytes
+ }
+
+ updated := payload
+ for i := 0; i < openAIWSMaxPrevResponseIDDeletePasses &&
+ gjson.GetBytes(updated, "previous_response_id").Exists(); i++ {
+ next, err := deleteFn(updated, "previous_response_id")
+ if err != nil {
+ return payload, false, err
+ }
+ updated = next
+ }
+ return updated, !gjson.GetBytes(updated, "previous_response_id").Exists(), nil
+}
+
+func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string) ([]byte, error) {
+ normalizedPrevID := strings.TrimSpace(previousResponseID)
+ if len(payload) == 0 || normalizedPrevID == "" {
+ return payload, nil
+ }
+ updated, err := sjson.SetBytes(payload, "previous_response_id", normalizedPrevID)
+ if err == nil {
+ return updated, nil
+ }
+
+ var reqBody map[string]any
+ if unmarshalErr := json.Unmarshal(payload, &reqBody); unmarshalErr != nil {
+ return nil, err
+ }
+ reqBody["previous_response_id"] = normalizedPrevID
+ rebuilt, marshalErr := json.Marshal(reqBody)
+ if marshalErr != nil {
+ return nil, marshalErr
+ }
+ return rebuilt, nil
+}
+
+func shouldInferIngressFunctionCallOutputPreviousResponseID(
+ storeDisabled bool,
+ turn int,
+ hasFunctionCallOutput bool,
+ currentPreviousResponseID string,
+ expectedPreviousResponseID string,
+) bool {
+ if !storeDisabled || turn <= 1 || !hasFunctionCallOutput {
+ return false
+ }
+ if strings.TrimSpace(currentPreviousResponseID) != "" {
+ return false
+ }
+ return strings.TrimSpace(expectedPreviousResponseID) != ""
+}
+
+func alignStoreDisabledPreviousResponseID(
+ payload []byte,
+ expectedPreviousResponseID string,
+) ([]byte, bool, error) {
+ if len(payload) == 0 {
+ return payload, false, nil
+ }
+ expected := strings.TrimSpace(expectedPreviousResponseID)
+ if expected == "" {
+ return payload, false, nil
+ }
+ current := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
+ if current == "" || current == expected {
+ return payload, false, nil
+ }
+
+ withoutPrev, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload)
+ if dropErr != nil {
+ return payload, false, dropErr
+ }
+ if !removed {
+ return payload, false, nil
+ }
+ updated, setErr := setPreviousResponseIDToRawPayload(withoutPrev, expected)
+ if setErr != nil {
+ return payload, false, setErr
+ }
+ return updated, true, nil
+}
+
+func cloneOpenAIWSPayloadBytes(payload []byte) []byte {
+ if len(payload) == 0 {
+ return nil
+ }
+ cloned := make([]byte, len(payload))
+ copy(cloned, payload)
+ return cloned
+}
+
+func cloneOpenAIWSRawMessages(items []json.RawMessage) []json.RawMessage {
+ if items == nil {
+ return nil
+ }
+ cloned := make([]json.RawMessage, 0, len(items))
+ for idx := range items {
+ cloned = append(cloned, json.RawMessage(cloneOpenAIWSPayloadBytes(items[idx])))
+ }
+ return cloned
+}
+
+func normalizeOpenAIWSJSONForCompare(raw []byte) ([]byte, error) {
+ trimmed := bytes.TrimSpace(raw)
+ if len(trimmed) == 0 {
+ return nil, errors.New("json is empty")
+ }
+ var decoded any
+ if err := json.Unmarshal(trimmed, &decoded); err != nil {
+ return nil, err
+ }
+ return json.Marshal(decoded)
+}
+
+func normalizeOpenAIWSJSONForCompareOrRaw(raw []byte) []byte {
+ normalized, err := normalizeOpenAIWSJSONForCompare(raw)
+ if err != nil {
+ return bytes.TrimSpace(raw)
+ }
+ return normalized
+}
+
+func normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload []byte) ([]byte, error) {
+ if len(payload) == 0 {
+ return nil, errors.New("payload is empty")
+ }
+ var decoded map[string]any
+ if err := json.Unmarshal(payload, &decoded); err != nil {
+ return nil, err
+ }
+ delete(decoded, "input")
+ delete(decoded, "previous_response_id")
+ return json.Marshal(decoded)
+}
+
+func openAIWSExtractNormalizedInputSequence(payload []byte) ([]json.RawMessage, bool, error) {
+ if len(payload) == 0 {
+ return nil, false, nil
+ }
+ inputValue := gjson.GetBytes(payload, "input")
+ if !inputValue.Exists() {
+ return nil, false, nil
+ }
+ if inputValue.Type == gjson.JSON {
+ raw := strings.TrimSpace(inputValue.Raw)
+ if strings.HasPrefix(raw, "[") {
+ var items []json.RawMessage
+ if err := json.Unmarshal([]byte(raw), &items); err != nil {
+ return nil, true, err
+ }
+ return items, true, nil
+ }
+ return []json.RawMessage{json.RawMessage(raw)}, true, nil
+ }
+ if inputValue.Type == gjson.String {
+ encoded, _ := json.Marshal(inputValue.String())
+ return []json.RawMessage{encoded}, true, nil
+ }
+ return []json.RawMessage{json.RawMessage(inputValue.Raw)}, true, nil
+}
+
+func openAIWSInputIsPrefixExtended(previousPayload, currentPayload []byte) (bool, error) {
+ previousItems, previousExists, prevErr := openAIWSExtractNormalizedInputSequence(previousPayload)
+ if prevErr != nil {
+ return false, prevErr
+ }
+ currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
+ if currentErr != nil {
+ return false, currentErr
+ }
+ if !previousExists && !currentExists {
+ return true, nil
+ }
+ if !previousExists {
+ return len(currentItems) == 0, nil
+ }
+ if !currentExists {
+ return len(previousItems) == 0, nil
+ }
+ if len(currentItems) < len(previousItems) {
+ return false, nil
+ }
+
+ for idx := range previousItems {
+ previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(previousItems[idx])
+ currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(currentItems[idx])
+ if !bytes.Equal(previousNormalized, currentNormalized) {
+ return false, nil
+ }
+ }
+ return true, nil
+}
+
+func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage) bool {
+ if len(prefix) == 0 {
+ return true
+ }
+ if len(items) < len(prefix) {
+ return false
+ }
+ for idx := range prefix {
+ previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(prefix[idx])
+ currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(items[idx])
+ if !bytes.Equal(previousNormalized, currentNormalized) {
+ return false
+ }
+ }
+ return true
+}
+
+func buildOpenAIWSReplayInputSequence(
+ previousFullInput []json.RawMessage,
+ previousFullInputExists bool,
+ currentPayload []byte,
+ hasPreviousResponseID bool,
+) ([]json.RawMessage, bool, error) {
+ currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
+ if currentErr != nil {
+ return nil, false, currentErr
+ }
+ if !hasPreviousResponseID {
+ return cloneOpenAIWSRawMessages(currentItems), currentExists, nil
+ }
+ if !previousFullInputExists {
+ return cloneOpenAIWSRawMessages(currentItems), currentExists, nil
+ }
+ if !currentExists || len(currentItems) == 0 {
+ return cloneOpenAIWSRawMessages(previousFullInput), true, nil
+ }
+ if openAIWSRawItemsHasPrefix(currentItems, previousFullInput) {
+ return cloneOpenAIWSRawMessages(currentItems), true, nil
+ }
+ merged := make([]json.RawMessage, 0, len(previousFullInput)+len(currentItems))
+ merged = append(merged, cloneOpenAIWSRawMessages(previousFullInput)...)
+ merged = append(merged, cloneOpenAIWSRawMessages(currentItems)...)
+ return merged, true, nil
+}
+
+func setOpenAIWSPayloadInputSequence(
+ payload []byte,
+ fullInput []json.RawMessage,
+ fullInputExists bool,
+) ([]byte, error) {
+ if !fullInputExists {
+ return payload, nil
+ }
+ // Preserve [] vs null semantics when input exists but is empty.
+ inputForMarshal := fullInput
+ if inputForMarshal == nil {
+ inputForMarshal = []json.RawMessage{}
+ }
+ inputRaw, marshalErr := json.Marshal(inputForMarshal)
+ if marshalErr != nil {
+ return nil, marshalErr
+ }
+ return sjson.SetRawBytes(payload, "input", inputRaw)
+}
+
+func shouldKeepIngressPreviousResponseID(
+ previousPayload []byte,
+ currentPayload []byte,
+ lastTurnResponseID string,
+ hasFunctionCallOutput bool,
+) (bool, string, error) {
+ if hasFunctionCallOutput {
+ return true, "has_function_call_output", nil
+ }
+ currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id"))
+ if currentPreviousResponseID == "" {
+ return false, "missing_previous_response_id", nil
+ }
+ expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID)
+ if expectedPreviousResponseID == "" {
+ return false, "missing_last_turn_response_id", nil
+ }
+ if currentPreviousResponseID != expectedPreviousResponseID {
+ return false, "previous_response_id_mismatch", nil
+ }
+ if len(previousPayload) == 0 {
+ return false, "missing_previous_turn_payload", nil
+ }
+
+ previousComparable, previousComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(previousPayload)
+ if previousComparableErr != nil {
+ return false, "non_input_compare_error", previousComparableErr
+ }
+ currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload)
+ if currentComparableErr != nil {
+ return false, "non_input_compare_error", currentComparableErr
+ }
+ if !bytes.Equal(previousComparable, currentComparable) {
+ return false, "non_input_changed", nil
+ }
+ return true, "strict_incremental_ok", nil
+}
+
+type openAIWSIngressPreviousTurnStrictState struct {
+ nonInputComparable []byte
+}
+
+func buildOpenAIWSIngressPreviousTurnStrictState(payload []byte) (*openAIWSIngressPreviousTurnStrictState, error) {
+ if len(payload) == 0 {
+ return nil, nil
+ }
+ nonInputComparable, nonInputErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload)
+ if nonInputErr != nil {
+ return nil, nonInputErr
+ }
+ return &openAIWSIngressPreviousTurnStrictState{
+ nonInputComparable: nonInputComparable,
+ }, nil
+}
+
+func shouldKeepIngressPreviousResponseIDWithStrictState(
+ previousState *openAIWSIngressPreviousTurnStrictState,
+ currentPayload []byte,
+ lastTurnResponseID string,
+ hasFunctionCallOutput bool,
+) (bool, string, error) {
+ if hasFunctionCallOutput {
+ return true, "has_function_call_output", nil
+ }
+ currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id"))
+ if currentPreviousResponseID == "" {
+ return false, "missing_previous_response_id", nil
+ }
+ expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID)
+ if expectedPreviousResponseID == "" {
+ return false, "missing_last_turn_response_id", nil
+ }
+ if currentPreviousResponseID != expectedPreviousResponseID {
+ return false, "previous_response_id_mismatch", nil
+ }
+ if previousState == nil {
+ return false, "missing_previous_turn_payload", nil
+ }
+
+ currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload)
+ if currentComparableErr != nil {
+ return false, "non_input_compare_error", currentComparableErr
+ }
+ if !bytes.Equal(previousState.nonInputComparable, currentComparable) {
+ return false, "non_input_changed", nil
+ }
+ return true, "strict_incremental_ok", nil
+}
+
+func (s *OpenAIGatewayService) forwardOpenAIWSV2(
+ ctx context.Context,
+ c *gin.Context,
+ account *Account,
+ reqBody map[string]any,
+ token string,
+ decision OpenAIWSProtocolDecision,
+ isCodexCLI bool,
+ reqStream bool,
+ originalModel string,
+ mappedModel string,
+ startTime time.Time,
+ attempt int,
+ lastFailureReason string,
+) (*OpenAIForwardResult, error) {
+ if s == nil || account == nil {
+ return nil, wrapOpenAIWSFallback("invalid_state", errors.New("service or account is nil"))
+ }
+
+ wsURL, err := s.buildOpenAIResponsesWSURL(account)
+ if err != nil {
+ return nil, wrapOpenAIWSFallback("build_ws_url", err)
+ }
+ wsHost := "-"
+ wsPath := "-"
+ if parsed, parseErr := url.Parse(wsURL); parseErr == nil && parsed != nil {
+ if h := strings.TrimSpace(parsed.Host); h != "" {
+ wsHost = normalizeOpenAIWSLogValue(h)
+ }
+ if p := strings.TrimSpace(parsed.Path); p != "" {
+ wsPath = normalizeOpenAIWSLogValue(p)
+ }
+ }
+ logOpenAIWSModeDebug(
+ "dial_target account_id=%d account_type=%s ws_host=%s ws_path=%s",
+ account.ID,
+ account.Type,
+ wsHost,
+ wsPath,
+ )
+
+ payload := s.buildOpenAIWSCreatePayload(reqBody, account)
+ payloadStrategy, removedKeys := applyOpenAIWSRetryPayloadStrategy(payload, attempt)
+ previousResponseID := openAIWSPayloadString(payload, "previous_response_id")
+ previousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
+ promptCacheKey := openAIWSPayloadString(payload, "prompt_cache_key")
+ _, hasTools := payload["tools"]
+ debugEnabled := isOpenAIWSModeDebugEnabled()
+ payloadBytes := -1
+ resolvePayloadBytes := func() int {
+ if payloadBytes >= 0 {
+ return payloadBytes
+ }
+ payloadBytes = len(payloadAsJSONBytes(payload))
+ return payloadBytes
+ }
+ streamValue := "-"
+ if raw, ok := payload["stream"]; ok {
+ streamValue = normalizeOpenAIWSLogValue(strings.TrimSpace(fmt.Sprintf("%v", raw)))
+ }
+ turnState := ""
+ turnMetadata := ""
+ if c != nil && c.Request != nil {
+ turnState = strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader))
+ turnMetadata = strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader))
+ }
+ setOpenAIWSTurnMetadata(payload, turnMetadata)
+ payloadEventType := openAIWSPayloadString(payload, "type")
+ if payloadEventType == "" {
+ payloadEventType = "response.create"
+ }
+ if s.shouldEmitOpenAIWSPayloadSchema(attempt) {
+ logOpenAIWSModeInfo(
+ "[debug] payload_schema account_id=%d attempt=%d event=%s payload_keys=%s payload_bytes=%d payload_key_sizes=%s input_summary=%s stream=%s payload_strategy=%s removed_keys=%s has_previous_response_id=%v has_prompt_cache_key=%v has_tools=%v",
+ account.ID,
+ attempt,
+ payloadEventType,
+ normalizeOpenAIWSLogValue(strings.Join(sortedKeys(payload), ",")),
+ resolvePayloadBytes(),
+ normalizeOpenAIWSLogValue(summarizeOpenAIWSPayloadKeySizes(payload, openAIWSPayloadKeySizeTopN)),
+ normalizeOpenAIWSLogValue(summarizeOpenAIWSInput(payload["input"])),
+ streamValue,
+ normalizeOpenAIWSLogValue(payloadStrategy),
+ normalizeOpenAIWSLogValue(strings.Join(removedKeys, ",")),
+ previousResponseID != "",
+ promptCacheKey != "",
+ hasTools,
+ )
+ }
+
+ stateStore := s.getOpenAIWSStateStore()
+ groupID := getOpenAIGroupIDFromContext(c)
+ sessionHash := s.GenerateSessionHash(c, nil)
+ if sessionHash == "" {
+ var legacySessionHash string
+ sessionHash, legacySessionHash = openAIWSSessionHashesFromID(promptCacheKey)
+ attachOpenAILegacySessionHashToGin(c, legacySessionHash)
+ }
+ if turnState == "" && stateStore != nil && sessionHash != "" {
+ if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok {
+ turnState = savedTurnState
+ }
+ }
+ preferredConnID := ""
+ if stateStore != nil && previousResponseID != "" {
+ if connID, ok := stateStore.GetResponseConn(previousResponseID); ok {
+ preferredConnID = connID
+ }
+ }
+ storeDisabled := s.isOpenAIWSStoreDisabledInRequest(reqBody, account)
+ if stateStore != nil && storeDisabled && previousResponseID == "" && sessionHash != "" {
+ if connID, ok := stateStore.GetSessionConn(groupID, sessionHash); ok {
+ preferredConnID = connID
+ }
+ }
+ storeDisabledConnMode := s.openAIWSStoreDisabledConnMode()
+ forceNewConnByPolicy := shouldForceNewConnOnStoreDisabled(storeDisabledConnMode, lastFailureReason)
+ forceNewConn := forceNewConnByPolicy && storeDisabled && previousResponseID == "" && sessionHash != "" && preferredConnID == ""
+ wsHeaders, sessionResolution := s.buildOpenAIWSHeaders(c, account, token, decision, isCodexCLI, turnState, turnMetadata, promptCacheKey)
+ logOpenAIWSModeDebug(
+ "acquire_start account_id=%d account_type=%s transport=%s preferred_conn_id=%s has_previous_response_id=%v session_hash=%s has_turn_state=%v turn_state_len=%d has_turn_metadata=%v turn_metadata_len=%d store_disabled=%v store_disabled_conn_mode=%s retry_last_reason=%s force_new_conn=%v header_user_agent=%s header_openai_beta=%s header_originator=%s header_accept_language=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_prompt_cache_key=%v has_chatgpt_account_id=%v has_authorization=%v has_session_id=%v has_conversation_id=%v proxy_enabled=%v",
+ account.ID,
+ account.Type,
+ normalizeOpenAIWSLogValue(string(decision.Transport)),
+ truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen),
+ previousResponseID != "",
+ truncateOpenAIWSLogValue(sessionHash, 12),
+ turnState != "",
+ len(turnState),
+ turnMetadata != "",
+ len(turnMetadata),
+ storeDisabled,
+ normalizeOpenAIWSLogValue(storeDisabledConnMode),
+ truncateOpenAIWSLogValue(lastFailureReason, openAIWSLogValueMaxLen),
+ forceNewConn,
+ openAIWSHeaderValueForLog(wsHeaders, "user-agent"),
+ openAIWSHeaderValueForLog(wsHeaders, "openai-beta"),
+ openAIWSHeaderValueForLog(wsHeaders, "originator"),
+ openAIWSHeaderValueForLog(wsHeaders, "accept-language"),
+ openAIWSHeaderValueForLog(wsHeaders, "session_id"),
+ openAIWSHeaderValueForLog(wsHeaders, "conversation_id"),
+ normalizeOpenAIWSLogValue(sessionResolution.SessionSource),
+ normalizeOpenAIWSLogValue(sessionResolution.ConversationSource),
+ promptCacheKey != "",
+ hasOpenAIWSHeader(wsHeaders, "chatgpt-account-id"),
+ hasOpenAIWSHeader(wsHeaders, "authorization"),
+ hasOpenAIWSHeader(wsHeaders, "session_id"),
+ hasOpenAIWSHeader(wsHeaders, "conversation_id"),
+ account.ProxyID != nil && account.Proxy != nil,
+ )
+
+ acquireCtx, acquireCancel := context.WithTimeout(ctx, s.openAIWSAcquireTimeout())
+ defer acquireCancel()
+
+ lease, err := s.getOpenAIWSConnPool().Acquire(acquireCtx, openAIWSAcquireRequest{
+ Account: account,
+ WSURL: wsURL,
+ Headers: wsHeaders,
+ PreferredConnID: preferredConnID,
+ ForceNewConn: forceNewConn,
+ ProxyURL: func() string {
+ if account.ProxyID != nil && account.Proxy != nil {
+ return account.Proxy.URL()
+ }
+ return ""
+ }(),
+ })
+ if err != nil {
+ dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(err)
+ logOpenAIWSModeInfo(
+ "acquire_fail account_id=%d account_type=%s transport=%s reason=%s dial_status=%d dial_class=%s dial_close_status=%s dial_close_reason=%s dial_resp_server=%s dial_resp_via=%s dial_resp_cf_ray=%s dial_resp_x_request_id=%s cause=%s preferred_conn_id=%s force_new_conn=%v ws_host=%s ws_path=%s proxy_enabled=%v",
+ account.ID,
+ account.Type,
+ normalizeOpenAIWSLogValue(string(decision.Transport)),
+ normalizeOpenAIWSLogValue(classifyOpenAIWSAcquireError(err)),
+ dialStatus,
+ dialClass,
+ dialCloseStatus,
+ truncateOpenAIWSLogValue(dialCloseReason, openAIWSHeaderValueMaxLen),
+ dialRespServer,
+ dialRespVia,
+ dialRespCFRay,
+ dialRespReqID,
+ truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen),
+ forceNewConn,
+ wsHost,
+ wsPath,
+ account.ProxyID != nil && account.Proxy != nil,
+ )
+ return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), err)
+ }
+ defer lease.Release()
+ connID := strings.TrimSpace(lease.ConnID())
+ logOpenAIWSModeDebug(
+ "connected account_id=%d account_type=%s transport=%s conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d has_previous_response_id=%v",
+ account.ID,
+ account.Type,
+ normalizeOpenAIWSLogValue(string(decision.Transport)),
+ connID,
+ lease.Reused(),
+ lease.ConnPickDuration().Milliseconds(),
+ lease.QueueWaitDuration().Milliseconds(),
+ previousResponseID != "",
+ )
+ if previousResponseID != "" {
+ logOpenAIWSModeInfo(
+ "continuation_probe account_id=%d account_type=%s conn_id=%s previous_response_id=%s previous_response_id_kind=%s preferred_conn_id=%s conn_reused=%v store_disabled=%v session_hash=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v",
+ account.ID,
+ account.Type,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(previousResponseIDKind),
+ truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen),
+ lease.Reused(),
+ storeDisabled,
+ truncateOpenAIWSLogValue(sessionHash, 12),
+ openAIWSHeaderValueForLog(wsHeaders, "session_id"),
+ openAIWSHeaderValueForLog(wsHeaders, "conversation_id"),
+ normalizeOpenAIWSLogValue(sessionResolution.SessionSource),
+ normalizeOpenAIWSLogValue(sessionResolution.ConversationSource),
+ turnState != "",
+ len(turnState),
+ promptCacheKey != "",
+ )
+ }
+ if c != nil {
+ SetOpsLatencyMs(c, OpsOpenAIWSConnPickMsKey, lease.ConnPickDuration().Milliseconds())
+ SetOpsLatencyMs(c, OpsOpenAIWSQueueWaitMsKey, lease.QueueWaitDuration().Milliseconds())
+ c.Set(OpsOpenAIWSConnReusedKey, lease.Reused())
+ if connID != "" {
+ c.Set(OpsOpenAIWSConnIDKey, connID)
+ }
+ }
+
+ handshakeTurnState := strings.TrimSpace(lease.HandshakeHeader(openAIWSTurnStateHeader))
+ logOpenAIWSModeDebug(
+ "handshake account_id=%d conn_id=%s has_turn_state=%v turn_state_len=%d",
+ account.ID,
+ connID,
+ handshakeTurnState != "",
+ len(handshakeTurnState),
+ )
+ if handshakeTurnState != "" {
+ if stateStore != nil && sessionHash != "" {
+ stateStore.BindSessionTurnState(groupID, sessionHash, handshakeTurnState, s.openAIWSSessionStickyTTL())
+ }
+ if c != nil {
+ c.Header(http.CanonicalHeaderKey(openAIWSTurnStateHeader), handshakeTurnState)
+ }
+ }
+
+ if err := s.performOpenAIWSGeneratePrewarm(
+ ctx,
+ lease,
+ decision,
+ payload,
+ previousResponseID,
+ reqBody,
+ account,
+ stateStore,
+ groupID,
+ ); err != nil {
+ return nil, err
+ }
+
+ if err := lease.WriteJSONWithContextTimeout(ctx, payload, s.openAIWSWriteTimeout()); err != nil {
+ lease.MarkBroken()
+ logOpenAIWSModeInfo(
+ "write_request_fail account_id=%d conn_id=%s cause=%s payload_bytes=%d",
+ account.ID,
+ connID,
+ truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
+ resolvePayloadBytes(),
+ )
+ return nil, wrapOpenAIWSFallback("write_request", err)
+ }
+ if debugEnabled {
+ logOpenAIWSModeDebug(
+ "write_request_sent account_id=%d conn_id=%s stream=%v payload_bytes=%d previous_response_id=%s",
+ account.ID,
+ connID,
+ reqStream,
+ resolvePayloadBytes(),
+ truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
+ )
+ }
+
+ usage := &OpenAIUsage{}
+ var firstTokenMs *int
+ responseID := ""
+ var finalResponse []byte
+ wroteDownstream := false
+ needModelReplace := originalModel != mappedModel
+ var mappedModelBytes []byte
+ if needModelReplace && mappedModel != "" {
+ mappedModelBytes = []byte(mappedModel)
+ }
+ bufferedStreamEvents := make([][]byte, 0, 4)
+ eventCount := 0
+ tokenEventCount := 0
+ terminalEventCount := 0
+ bufferedEventCount := 0
+ flushedBufferedEventCount := 0
+ firstEventType := ""
+ lastEventType := ""
+
+ var flusher http.Flusher
+ if reqStream {
+ if s.responseHeaderFilter != nil {
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), http.Header{}, s.responseHeaderFilter)
+ }
+ c.Header("Content-Type", "text/event-stream")
+ c.Header("Cache-Control", "no-cache")
+ c.Header("Connection", "keep-alive")
+ c.Header("X-Accel-Buffering", "no")
+ f, ok := c.Writer.(http.Flusher)
+ if !ok {
+ lease.MarkBroken()
+ return nil, wrapOpenAIWSFallback("streaming_not_supported", errors.New("streaming not supported"))
+ }
+ flusher = f
+ }
+
+ clientDisconnected := false
+ flushBatchSize := s.openAIWSEventFlushBatchSize()
+ flushInterval := s.openAIWSEventFlushInterval()
+ pendingFlushEvents := 0
+ lastFlushAt := time.Now()
+ flushStreamWriter := func(force bool) {
+ if clientDisconnected || flusher == nil || pendingFlushEvents <= 0 {
+ return
+ }
+ if !force && flushBatchSize > 1 && pendingFlushEvents < flushBatchSize {
+ if flushInterval <= 0 || time.Since(lastFlushAt) < flushInterval {
+ return
+ }
+ }
+ flusher.Flush()
+ pendingFlushEvents = 0
+ lastFlushAt = time.Now()
+ }
+ emitStreamMessage := func(message []byte, forceFlush bool) {
+ if clientDisconnected {
+ return
+ }
+ frame := make([]byte, 0, len(message)+8)
+ frame = append(frame, "data: "...)
+ frame = append(frame, message...)
+ frame = append(frame, '\n', '\n')
+ _, wErr := c.Writer.Write(frame)
+ if wErr == nil {
+ wroteDownstream = true
+ pendingFlushEvents++
+ flushStreamWriter(forceFlush)
+ return
+ }
+ clientDisconnected = true
+ logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode] client disconnected, continue draining upstream: account=%d", account.ID)
+ }
+ flushBufferedStreamEvents := func(reason string) {
+ if len(bufferedStreamEvents) == 0 {
+ return
+ }
+ flushed := len(bufferedStreamEvents)
+ for _, buffered := range bufferedStreamEvents {
+ emitStreamMessage(buffered, false)
+ }
+ bufferedStreamEvents = bufferedStreamEvents[:0]
+ flushStreamWriter(true)
+ flushedBufferedEventCount += flushed
+ if debugEnabled {
+ logOpenAIWSModeDebug(
+ "buffer_flush account_id=%d conn_id=%s reason=%s flushed=%d total_flushed=%d client_disconnected=%v",
+ account.ID,
+ connID,
+ truncateOpenAIWSLogValue(reason, openAIWSLogValueMaxLen),
+ flushed,
+ flushedBufferedEventCount,
+ clientDisconnected,
+ )
+ }
+ }
+
+ readTimeout := s.openAIWSReadTimeout()
+
+ for {
+ message, readErr := lease.ReadMessageWithContextTimeout(ctx, readTimeout)
+ if readErr != nil {
+ lease.MarkBroken()
+ closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr)
+ logOpenAIWSModeInfo(
+ "read_fail account_id=%d conn_id=%s wrote_downstream=%v close_status=%s close_reason=%s cause=%s events=%d token_events=%d terminal_events=%d buffered_pending=%d buffered_flushed=%d first_event=%s last_event=%s",
+ account.ID,
+ connID,
+ wroteDownstream,
+ closeStatus,
+ closeReason,
+ truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen),
+ eventCount,
+ tokenEventCount,
+ terminalEventCount,
+ len(bufferedStreamEvents),
+ flushedBufferedEventCount,
+ truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen),
+ )
+ if !wroteDownstream {
+ return nil, wrapOpenAIWSFallback(classifyOpenAIWSReadFallbackReason(readErr), readErr)
+ }
+ if clientDisconnected {
+ break
+ }
+ setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(readErr.Error()), "")
+ return nil, fmt.Errorf("openai ws read event: %w", readErr)
+ }
+
+ eventType, eventResponseID, responseField := parseOpenAIWSEventEnvelope(message)
+ if eventType == "" {
+ continue
+ }
+ eventCount++
+ if firstEventType == "" {
+ firstEventType = eventType
+ }
+ lastEventType = eventType
+
+ if responseID == "" && eventResponseID != "" {
+ responseID = eventResponseID
+ }
+
+ isTokenEvent := isOpenAIWSTokenEvent(eventType)
+ if isTokenEvent {
+ tokenEventCount++
+ }
+ isTerminalEvent := isOpenAIWSTerminalEvent(eventType)
+ if isTerminalEvent {
+ terminalEventCount++
+ }
+ if firstTokenMs == nil && isTokenEvent {
+ ms := int(time.Since(startTime).Milliseconds())
+ firstTokenMs = &ms
+ }
+ if debugEnabled && shouldLogOpenAIWSEvent(eventCount, eventType) {
+ logOpenAIWSModeDebug(
+ "event_received account_id=%d conn_id=%s idx=%d type=%s bytes=%d token=%v terminal=%v buffered_pending=%d",
+ account.ID,
+ connID,
+ eventCount,
+ truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
+ len(message),
+ isTokenEvent,
+ isTerminalEvent,
+ len(bufferedStreamEvents),
+ )
+ }
+
+ if !clientDisconnected {
+ if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(message, mappedModelBytes) {
+ message = replaceOpenAIWSMessageModel(message, mappedModel, originalModel)
+ }
+ if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(message) {
+ if corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(message); changed {
+ message = corrected
+ }
+ }
+ }
+ if openAIWSEventShouldParseUsage(eventType) {
+ parseOpenAIWSResponseUsageFromCompletedEvent(message, usage)
+ }
+
+ if eventType == "error" {
+ errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
+ errMsg := strings.TrimSpace(errMsgRaw)
+ if errMsg == "" {
+ errMsg = "Upstream websocket error"
+ }
+ fallbackReason, canFallback := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
+ errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
+ logOpenAIWSModeInfo(
+ "error_event account_id=%d conn_id=%s idx=%d fallback_reason=%s can_fallback=%v err_code=%s err_type=%s err_message=%s",
+ account.ID,
+ connID,
+ eventCount,
+ truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen),
+ canFallback,
+ errCode,
+ errType,
+ errMessage,
+ )
+ if fallbackReason == "previous_response_not_found" {
+ logOpenAIWSModeInfo(
+ "previous_response_not_found_diag account_id=%d account_type=%s conn_id=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s event_idx=%d req_stream=%v store_disabled=%v conn_reused=%v session_hash=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v err_code=%s err_type=%s err_message=%s",
+ account.ID,
+ account.Type,
+ connID,
+ truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(previousResponseIDKind),
+ truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
+ eventCount,
+ reqStream,
+ storeDisabled,
+ lease.Reused(),
+ truncateOpenAIWSLogValue(sessionHash, 12),
+ openAIWSHeaderValueForLog(wsHeaders, "session_id"),
+ openAIWSHeaderValueForLog(wsHeaders, "conversation_id"),
+ normalizeOpenAIWSLogValue(sessionResolution.SessionSource),
+ normalizeOpenAIWSLogValue(sessionResolution.ConversationSource),
+ turnState != "",
+ len(turnState),
+ promptCacheKey != "",
+ errCode,
+ errType,
+ errMessage,
+ )
+ }
+ // error 事件后连接不再可复用,避免回池后污染下一请求。
+ lease.MarkBroken()
+ if !wroteDownstream && canFallback {
+ return nil, wrapOpenAIWSFallback(fallbackReason, errors.New(errMsg))
+ }
+ statusCode := openAIWSErrorHTTPStatusFromRaw(errCodeRaw, errTypeRaw)
+ setOpsUpstreamError(c, statusCode, errMsg, "")
+ if reqStream && !clientDisconnected {
+ flushBufferedStreamEvents("error_event")
+ emitStreamMessage(message, true)
+ }
+ if !reqStream {
+ c.JSON(statusCode, gin.H{
+ "error": gin.H{
+ "type": "upstream_error",
+ "message": errMsg,
+ },
+ })
+ }
+ return nil, fmt.Errorf("openai ws error event: %s", errMsg)
+ }
+
+ if reqStream {
+ // 在首个 token 前先缓冲事件(如 response.created),
+ // 以便上游早期断连时仍可安全回退到 HTTP,不给下游发送半截流。
+ shouldBuffer := firstTokenMs == nil && !isTokenEvent && !isTerminalEvent
+ if shouldBuffer {
+ buffered := make([]byte, len(message))
+ copy(buffered, message)
+ bufferedStreamEvents = append(bufferedStreamEvents, buffered)
+ bufferedEventCount++
+ if debugEnabled && shouldLogOpenAIWSBufferedEvent(bufferedEventCount) {
+ logOpenAIWSModeDebug(
+ "buffer_enqueue account_id=%d conn_id=%s idx=%d event_idx=%d event_type=%s buffer_size=%d",
+ account.ID,
+ connID,
+ bufferedEventCount,
+ eventCount,
+ truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
+ len(bufferedStreamEvents),
+ )
+ }
+ } else {
+ flushBufferedStreamEvents(eventType)
+ emitStreamMessage(message, isTerminalEvent)
+ }
+ } else {
+ if responseField.Exists() && responseField.Type == gjson.JSON {
+ finalResponse = []byte(responseField.Raw)
+ }
+ }
+
+ if isTerminalEvent {
+ break
+ }
+ }
+
+ if !reqStream {
+ if len(finalResponse) == 0 {
+ logOpenAIWSModeInfo(
+ "missing_final_response account_id=%d conn_id=%s events=%d token_events=%d terminal_events=%d wrote_downstream=%v",
+ account.ID,
+ connID,
+ eventCount,
+ tokenEventCount,
+ terminalEventCount,
+ wroteDownstream,
+ )
+ if !wroteDownstream {
+ return nil, wrapOpenAIWSFallback("missing_final_response", errors.New("no terminal response payload"))
+ }
+ return nil, errors.New("ws finished without final response")
+ }
+
+ if needModelReplace {
+ finalResponse = s.replaceModelInResponseBody(finalResponse, mappedModel, originalModel)
+ }
+ finalResponse = s.correctToolCallsInResponseBody(finalResponse)
+ populateOpenAIUsageFromResponseJSON(finalResponse, usage)
+ if responseID == "" {
+ responseID = strings.TrimSpace(gjson.GetBytes(finalResponse, "id").String())
+ }
+
+ c.Data(http.StatusOK, "application/json", finalResponse)
+ } else {
+ flushStreamWriter(true)
+ }
+
+ if responseID != "" && stateStore != nil {
+ ttl := s.openAIWSResponseStickyTTL()
+ logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl))
+ stateStore.BindResponseConn(responseID, lease.ConnID(), ttl)
+ }
+ if stateStore != nil && storeDisabled && sessionHash != "" {
+ stateStore.BindSessionConn(groupID, sessionHash, lease.ConnID(), s.openAIWSSessionStickyTTL())
+ }
+ firstTokenMsValue := -1
+ if firstTokenMs != nil {
+ firstTokenMsValue = *firstTokenMs
+ }
+ logOpenAIWSModeDebug(
+ "completed account_id=%d conn_id=%s response_id=%s stream=%v duration_ms=%d events=%d token_events=%d terminal_events=%d buffered_events=%d buffered_flushed=%d first_event=%s last_event=%s first_token_ms=%d wrote_downstream=%v client_disconnected=%v",
+ account.ID,
+ connID,
+ truncateOpenAIWSLogValue(strings.TrimSpace(responseID), openAIWSIDValueMaxLen),
+ reqStream,
+ time.Since(startTime).Milliseconds(),
+ eventCount,
+ tokenEventCount,
+ terminalEventCount,
+ bufferedEventCount,
+ flushedBufferedEventCount,
+ truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen),
+ firstTokenMsValue,
+ wroteDownstream,
+ clientDisconnected,
+ )
+
+ return &OpenAIForwardResult{
+ RequestID: responseID,
+ Usage: *usage,
+ Model: originalModel,
+ ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel),
+ Stream: reqStream,
+ OpenAIWSMode: true,
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ }, nil
+}
+
+// ProxyResponsesWebSocketFromClient 处理客户端入站 WebSocket(OpenAI Responses WS Mode)并转发到上游。
+// 当前实现按“单请求 -> 终止事件 -> 下一请求”的顺序代理,适配 Codex CLI 的 turn 模式。
+func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
+ ctx context.Context,
+ c *gin.Context,
+ clientConn *coderws.Conn,
+ account *Account,
+ token string,
+ firstClientMessage []byte,
+ hooks *OpenAIWSIngressHooks,
+) error {
+ if s == nil {
+ return errors.New("service is nil")
+ }
+ if c == nil {
+ return errors.New("gin context is nil")
+ }
+ if clientConn == nil {
+ return errors.New("client websocket is nil")
+ }
+ if account == nil {
+ return errors.New("account is nil")
+ }
+ if strings.TrimSpace(token) == "" {
+ return errors.New("token is empty")
+ }
+
+ wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
+ modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
+ ingressMode := OpenAIWSIngressModeShared
+ if modeRouterV2Enabled {
+ ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault)
+ if ingressMode == OpenAIWSIngressModeOff {
+ return NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ "websocket mode is disabled for this account",
+ nil,
+ )
+ }
+ }
+ if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
+ return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)
+ }
+ dedicatedMode := modeRouterV2Enabled && ingressMode == OpenAIWSIngressModeDedicated
+
+ wsURL, err := s.buildOpenAIResponsesWSURL(account)
+ if err != nil {
+ return fmt.Errorf("build ws url: %w", err)
+ }
+ wsHost := "-"
+ wsPath := "-"
+ if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil {
+ wsHost = normalizeOpenAIWSLogValue(parsedURL.Host)
+ wsPath = normalizeOpenAIWSLogValue(parsedURL.Path)
+ }
+ debugEnabled := isOpenAIWSModeDebugEnabled()
+
+ type openAIWSClientPayload struct {
+ payloadRaw []byte
+ rawForHash []byte
+ promptCacheKey string
+ previousResponseID string
+ originalModel string
+ payloadBytes int
+ }
+
+ applyPayloadMutation := func(current []byte, path string, value any) ([]byte, error) {
+ next, err := sjson.SetBytes(current, path, value)
+ if err == nil {
+ return next, nil
+ }
+
+ // 仅在确实需要修改 payload 且 sjson 失败时,退回 map 路径确保兼容性。
+ payload := make(map[string]any)
+ if unmarshalErr := json.Unmarshal(current, &payload); unmarshalErr != nil {
+ return nil, err
+ }
+ switch path {
+ case "type", "model":
+ payload[path] = value
+ case "client_metadata." + openAIWSTurnMetadataHeader:
+ setOpenAIWSTurnMetadata(payload, fmt.Sprintf("%v", value))
+ default:
+ return nil, err
+ }
+ rebuilt, marshalErr := json.Marshal(payload)
+ if marshalErr != nil {
+ return nil, marshalErr
+ }
+ return rebuilt, nil
+ }
+
+ parseClientPayload := func(raw []byte) (openAIWSClientPayload, error) {
+ trimmed := bytes.TrimSpace(raw)
+ if len(trimmed) == 0 {
+ return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "empty websocket request payload", nil)
+ }
+ if !gjson.ValidBytes(trimmed) {
+ return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json"))
+ }
+
+ values := gjson.GetManyBytes(trimmed, "type", "model", "prompt_cache_key", "previous_response_id")
+ eventType := strings.TrimSpace(values[0].String())
+ normalized := trimmed
+ switch eventType {
+ case "":
+ eventType = "response.create"
+ next, setErr := applyPayloadMutation(normalized, "type", eventType)
+ if setErr != nil {
+ return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr)
+ }
+ normalized = next
+ case "response.create":
+ case "response.append":
+ return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ "response.append is not supported in ws v2; use response.create with previous_response_id",
+ nil,
+ )
+ default:
+ return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ fmt.Sprintf("unsupported websocket request type: %s", eventType),
+ nil,
+ )
+ }
+
+ originalModel := strings.TrimSpace(values[1].String())
+ if originalModel == "" {
+ return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ "model is required in response.create payload",
+ nil,
+ )
+ }
+ promptCacheKey := strings.TrimSpace(values[2].String())
+ previousResponseID := strings.TrimSpace(values[3].String())
+ previousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
+ if previousResponseID != "" && previousResponseIDKind == OpenAIPreviousResponseIDKindMessageID {
+ return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ "previous_response_id must be a response.id (resp_*), not a message id",
+ nil,
+ )
+ }
+ if turnMetadata := strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)); turnMetadata != "" {
+ next, setErr := applyPayloadMutation(normalized, "client_metadata."+openAIWSTurnMetadataHeader, turnMetadata)
+ if setErr != nil {
+ return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr)
+ }
+ normalized = next
+ }
+ mappedModel := account.GetMappedModel(originalModel)
+ if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" {
+ mappedModel = normalizedModel
+ }
+ if mappedModel != originalModel {
+ next, setErr := applyPayloadMutation(normalized, "model", mappedModel)
+ if setErr != nil {
+ return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr)
+ }
+ normalized = next
+ }
+
+ return openAIWSClientPayload{
+ payloadRaw: normalized,
+ rawForHash: trimmed,
+ promptCacheKey: promptCacheKey,
+ previousResponseID: previousResponseID,
+ originalModel: originalModel,
+ payloadBytes: len(normalized),
+ }, nil
+ }
+
+ firstPayload, err := parseClientPayload(firstClientMessage)
+ if err != nil {
+ return err
+ }
+
+ turnState := strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader))
+ stateStore := s.getOpenAIWSStateStore()
+ groupID := getOpenAIGroupIDFromContext(c)
+ sessionHash := s.GenerateSessionHash(c, firstPayload.rawForHash)
+ if turnState == "" && stateStore != nil && sessionHash != "" {
+ if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok {
+ turnState = savedTurnState
+ }
+ }
+
+ preferredConnID := ""
+ if stateStore != nil && firstPayload.previousResponseID != "" {
+ if connID, ok := stateStore.GetResponseConn(firstPayload.previousResponseID); ok {
+ preferredConnID = connID
+ }
+ }
+
+ storeDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(firstPayload.payloadRaw, account)
+ storeDisabledConnMode := s.openAIWSStoreDisabledConnMode()
+ if stateStore != nil && storeDisabled && firstPayload.previousResponseID == "" && sessionHash != "" {
+ if connID, ok := stateStore.GetSessionConn(groupID, sessionHash); ok {
+ preferredConnID = connID
+ }
+ }
+
+ isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
+ wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey)
+ baseAcquireReq := openAIWSAcquireRequest{
+ Account: account,
+ WSURL: wsURL,
+ Headers: wsHeaders,
+ ProxyURL: func() string {
+ if account.ProxyID != nil && account.Proxy != nil {
+ return account.Proxy.URL()
+ }
+ return ""
+ }(),
+ ForceNewConn: false,
+ }
+ pool := s.getOpenAIWSConnPool()
+ if pool == nil {
+ return errors.New("openai ws conn pool is nil")
+ }
+
+ logOpenAIWSModeInfo(
+ "ingress_ws_protocol_confirm account_id=%d account_type=%s transport=%s ws_host=%s ws_path=%s ws_mode=%s store_disabled=%v has_session_hash=%v has_previous_response_id=%v",
+ account.ID,
+ account.Type,
+ normalizeOpenAIWSLogValue(string(wsDecision.Transport)),
+ wsHost,
+ wsPath,
+ normalizeOpenAIWSLogValue(ingressMode),
+ storeDisabled,
+ sessionHash != "",
+ firstPayload.previousResponseID != "",
+ )
+
+ if debugEnabled {
+ logOpenAIWSModeDebug(
+ "ingress_ws_start account_id=%d account_type=%s transport=%s ws_host=%s preferred_conn_id=%s has_session_hash=%v has_previous_response_id=%v store_disabled=%v",
+ account.ID,
+ account.Type,
+ normalizeOpenAIWSLogValue(string(wsDecision.Transport)),
+ wsHost,
+ truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen),
+ sessionHash != "",
+ firstPayload.previousResponseID != "",
+ storeDisabled,
+ )
+ }
+ if firstPayload.previousResponseID != "" {
+ firstPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(firstPayload.previousResponseID)
+ logOpenAIWSModeInfo(
+ "ingress_ws_continuation_probe account_id=%d turn=%d previous_response_id=%s previous_response_id_kind=%s preferred_conn_id=%s session_hash=%s header_session_id=%s header_conversation_id=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v store_disabled=%v",
+ account.ID,
+ 1,
+ truncateOpenAIWSLogValue(firstPayload.previousResponseID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(firstPreviousResponseIDKind),
+ truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(sessionHash, 12),
+ openAIWSHeaderValueForLog(baseAcquireReq.Headers, "session_id"),
+ openAIWSHeaderValueForLog(baseAcquireReq.Headers, "conversation_id"),
+ turnState != "",
+ len(turnState),
+ firstPayload.promptCacheKey != "",
+ storeDisabled,
+ )
+ }
+
+ acquireTimeout := s.openAIWSAcquireTimeout()
+ if acquireTimeout <= 0 {
+ acquireTimeout = 30 * time.Second
+ }
+
+ acquireTurnLease := func(turn int, preferred string, forcePreferredConn bool) (*openAIWSConnLease, error) {
+ req := cloneOpenAIWSAcquireRequest(baseAcquireReq)
+ req.PreferredConnID = strings.TrimSpace(preferred)
+ req.ForcePreferredConn = forcePreferredConn
+ // dedicated 模式下每次获取均新建连接,避免跨会话复用残留上下文。
+ req.ForceNewConn = dedicatedMode
+ acquireCtx, acquireCancel := context.WithTimeout(ctx, acquireTimeout)
+ lease, acquireErr := pool.Acquire(acquireCtx, req)
+ acquireCancel()
+ if acquireErr != nil {
+ dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(acquireErr)
+ logOpenAIWSModeInfo(
+ "ingress_ws_upstream_acquire_fail account_id=%d turn=%d reason=%s dial_status=%d dial_class=%s dial_close_status=%s dial_close_reason=%s dial_resp_server=%s dial_resp_via=%s dial_resp_cf_ray=%s dial_resp_x_request_id=%s cause=%s preferred_conn_id=%s force_preferred_conn=%v ws_host=%s ws_path=%s proxy_enabled=%v",
+ account.ID,
+ turn,
+ normalizeOpenAIWSLogValue(classifyOpenAIWSAcquireError(acquireErr)),
+ dialStatus,
+ dialClass,
+ dialCloseStatus,
+ truncateOpenAIWSLogValue(dialCloseReason, openAIWSHeaderValueMaxLen),
+ dialRespServer,
+ dialRespVia,
+ dialRespCFRay,
+ dialRespReqID,
+ truncateOpenAIWSLogValue(acquireErr.Error(), openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(preferred, openAIWSIDValueMaxLen),
+ forcePreferredConn,
+ wsHost,
+ wsPath,
+ account.ProxyID != nil && account.Proxy != nil,
+ )
+ if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) {
+ return nil, NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ "upstream continuation connection is unavailable; please restart the conversation",
+ acquireErr,
+ )
+ }
+ if errors.Is(acquireErr, context.DeadlineExceeded) || errors.Is(acquireErr, errOpenAIWSConnQueueFull) {
+ return nil, NewOpenAIWSClientCloseError(
+ coderws.StatusTryAgainLater,
+ "upstream websocket is busy, please retry later",
+ acquireErr,
+ )
+ }
+ return nil, acquireErr
+ }
+ connID := strings.TrimSpace(lease.ConnID())
+ if handshakeTurnState := strings.TrimSpace(lease.HandshakeHeader(openAIWSTurnStateHeader)); handshakeTurnState != "" {
+ turnState = handshakeTurnState
+ if stateStore != nil && sessionHash != "" {
+ stateStore.BindSessionTurnState(groupID, sessionHash, handshakeTurnState, s.openAIWSSessionStickyTTL())
+ }
+ updatedHeaders := cloneHeader(baseAcquireReq.Headers)
+ if updatedHeaders == nil {
+ updatedHeaders = make(http.Header)
+ }
+ updatedHeaders.Set(openAIWSTurnStateHeader, handshakeTurnState)
+ baseAcquireReq.Headers = updatedHeaders
+ }
+ logOpenAIWSModeInfo(
+ "ingress_ws_upstream_connected account_id=%d turn=%d conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d preferred_conn_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ lease.Reused(),
+ lease.ConnPickDuration().Milliseconds(),
+ lease.QueueWaitDuration().Milliseconds(),
+ truncateOpenAIWSLogValue(preferred, openAIWSIDValueMaxLen),
+ )
+ return lease, nil
+ }
+
+ writeClientMessage := func(message []byte) error {
+ writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout())
+ defer cancel()
+ return clientConn.Write(writeCtx, coderws.MessageText, message)
+ }
+
+ readClientMessage := func() ([]byte, error) {
+ msgType, payload, readErr := clientConn.Read(ctx)
+ if readErr != nil {
+ return nil, readErr
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ return nil, NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ fmt.Sprintf("unsupported websocket client message type: %s", msgType.String()),
+ nil,
+ )
+ }
+ return payload, nil
+ }
+
+ sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string) (*OpenAIForwardResult, error) {
+ if lease == nil {
+ return nil, errors.New("upstream websocket lease is nil")
+ }
+ turnStart := time.Now()
+ wroteDownstream := false
+ if err := lease.WriteJSONWithContextTimeout(ctx, json.RawMessage(payload), s.openAIWSWriteTimeout()); err != nil {
+ return nil, wrapOpenAIWSIngressTurnError(
+ "write_upstream",
+ fmt.Errorf("write upstream websocket request: %w", err),
+ false,
+ )
+ }
+ if debugEnabled {
+ logOpenAIWSModeDebug(
+ "ingress_ws_turn_request_sent account_id=%d turn=%d conn_id=%s payload_bytes=%d",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ payloadBytes,
+ )
+ }
+
+ responseID := ""
+ usage := OpenAIUsage{}
+ var firstTokenMs *int
+ reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true)
+ turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
+ turnPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(turnPreviousResponseID)
+ turnPromptCacheKey := openAIWSPayloadStringFromRaw(payload, "prompt_cache_key")
+ turnStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(payload, account)
+ turnHasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists()
+ eventCount := 0
+ tokenEventCount := 0
+ terminalEventCount := 0
+ firstEventType := ""
+ lastEventType := ""
+ needModelReplace := false
+ clientDisconnected := false
+ mappedModel := ""
+ var mappedModelBytes []byte
+ if originalModel != "" {
+ mappedModel = account.GetMappedModel(originalModel)
+ if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" {
+ mappedModel = normalizedModel
+ }
+ needModelReplace = mappedModel != "" && mappedModel != originalModel
+ if needModelReplace {
+ mappedModelBytes = []byte(mappedModel)
+ }
+ }
+ for {
+ upstreamMessage, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout())
+ if readErr != nil {
+ lease.MarkBroken()
+ return nil, wrapOpenAIWSIngressTurnError(
+ "read_upstream",
+ fmt.Errorf("read upstream websocket event: %w", readErr),
+ wroteDownstream,
+ )
+ }
+
+ eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(upstreamMessage)
+ if responseID == "" && eventResponseID != "" {
+ responseID = eventResponseID
+ }
+ if eventType != "" {
+ eventCount++
+ if firstEventType == "" {
+ firstEventType = eventType
+ }
+ lastEventType = eventType
+ }
+ if eventType == "error" {
+ errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(upstreamMessage)
+ fallbackReason, _ := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
+ errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
+ recoverablePrevNotFound := fallbackReason == openAIWSIngressStagePreviousResponseNotFound &&
+ turnPreviousResponseID != "" &&
+ !turnHasFunctionCallOutput &&
+ s.openAIWSIngressPreviousResponseRecoveryEnabled() &&
+ !wroteDownstream
+ if recoverablePrevNotFound {
+ // 可恢复场景使用非 error 关键字日志,避免被 LegacyPrintf 误判为 ERROR 级别。
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_recoverable account_id=%d turn=%d conn_id=%s idx=%d reason=%s code=%s type=%s message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ eventCount,
+ truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen),
+ errCode,
+ errType,
+ errMessage,
+ truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(turnPreviousResponseIDKind),
+ truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
+ turnStoreDisabled,
+ turnPromptCacheKey != "",
+ )
+ } else {
+ logOpenAIWSModeInfo(
+ "ingress_ws_error_event account_id=%d turn=%d conn_id=%s idx=%d fallback_reason=%s err_code=%s err_type=%s err_message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ eventCount,
+ truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen),
+ errCode,
+ errType,
+ errMessage,
+ truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(turnPreviousResponseIDKind),
+ truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
+ turnStoreDisabled,
+ turnPromptCacheKey != "",
+ )
+ }
+ // previous_response_not_found 在 ingress 模式支持单次恢复重试:
+ // 不把该 error 直接下发客户端,而是由上层去掉 previous_response_id 后重放当前 turn。
+ if recoverablePrevNotFound {
+ lease.MarkBroken()
+ errMsg := strings.TrimSpace(errMsgRaw)
+ if errMsg == "" {
+ errMsg = "previous response not found"
+ }
+ return nil, wrapOpenAIWSIngressTurnError(
+ openAIWSIngressStagePreviousResponseNotFound,
+ errors.New(errMsg),
+ false,
+ )
+ }
+ }
+ isTokenEvent := isOpenAIWSTokenEvent(eventType)
+ if isTokenEvent {
+ tokenEventCount++
+ }
+ isTerminalEvent := isOpenAIWSTerminalEvent(eventType)
+ if isTerminalEvent {
+ terminalEventCount++
+ }
+ if firstTokenMs == nil && isTokenEvent {
+ ms := int(time.Since(turnStart).Milliseconds())
+ firstTokenMs = &ms
+ }
+ if openAIWSEventShouldParseUsage(eventType) {
+ parseOpenAIWSResponseUsageFromCompletedEvent(upstreamMessage, &usage)
+ }
+
+ if !clientDisconnected {
+ if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(upstreamMessage, mappedModelBytes) {
+ upstreamMessage = replaceOpenAIWSMessageModel(upstreamMessage, mappedModel, originalModel)
+ }
+ if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(upstreamMessage) {
+ if corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(upstreamMessage); changed {
+ upstreamMessage = corrected
+ }
+ }
+ if err := writeClientMessage(upstreamMessage); err != nil {
+ if isOpenAIWSClientDisconnectError(err) {
+ clientDisconnected = true
+ closeStatus, closeReason := summarizeOpenAIWSReadCloseError(err)
+ logOpenAIWSModeInfo(
+ "ingress_ws_client_disconnected_drain account_id=%d turn=%d conn_id=%s close_status=%s close_reason=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ closeStatus,
+ truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen),
+ )
+ } else {
+ return nil, wrapOpenAIWSIngressTurnError(
+ "write_client",
+ fmt.Errorf("write client websocket event: %w", err),
+ wroteDownstream,
+ )
+ }
+ } else {
+ wroteDownstream = true
+ }
+ }
+ if isTerminalEvent {
+ // 客户端已断连时,上游连接的 session 状态不可信,标记 broken 避免回池复用。
+ if clientDisconnected {
+ lease.MarkBroken()
+ }
+ firstTokenMsValue := -1
+ if firstTokenMs != nil {
+ firstTokenMsValue = *firstTokenMs
+ }
+ if debugEnabled {
+ logOpenAIWSModeDebug(
+ "ingress_ws_turn_completed account_id=%d turn=%d conn_id=%s response_id=%s duration_ms=%d events=%d token_events=%d terminal_events=%d first_event=%s last_event=%s first_token_ms=%d client_disconnected=%v",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
+ time.Since(turnStart).Milliseconds(),
+ eventCount,
+ tokenEventCount,
+ terminalEventCount,
+ truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen),
+ firstTokenMsValue,
+ clientDisconnected,
+ )
+ }
+ return &OpenAIForwardResult{
+ RequestID: responseID,
+ Usage: usage,
+ Model: originalModel,
+ ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel),
+ Stream: reqStream,
+ OpenAIWSMode: true,
+ Duration: time.Since(turnStart),
+ FirstTokenMs: firstTokenMs,
+ }, nil
+ }
+ }
+ }
+
+ currentPayload := firstPayload.payloadRaw
+ currentOriginalModel := firstPayload.originalModel
+ currentPayloadBytes := firstPayload.payloadBytes
+ isStrictAffinityTurn := func(payload []byte) bool {
+ if !storeDisabled {
+ return false
+ }
+ return strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id")) != ""
+ }
+ var sessionLease *openAIWSConnLease
+ sessionConnID := ""
+ pinnedSessionConnID := ""
+ unpinSessionConn := func(connID string) {
+ connID = strings.TrimSpace(connID)
+ if connID == "" || pinnedSessionConnID != connID {
+ return
+ }
+ pool.UnpinConn(account.ID, connID)
+ pinnedSessionConnID = ""
+ }
+ pinSessionConn := func(connID string) {
+ if !storeDisabled {
+ return
+ }
+ connID = strings.TrimSpace(connID)
+ if connID == "" || pinnedSessionConnID == connID {
+ return
+ }
+ if pinnedSessionConnID != "" {
+ pool.UnpinConn(account.ID, pinnedSessionConnID)
+ pinnedSessionConnID = ""
+ }
+ if pool.PinConn(account.ID, connID) {
+ pinnedSessionConnID = connID
+ }
+ }
+ releaseSessionLease := func() {
+ if sessionLease == nil {
+ return
+ }
+ if dedicatedMode {
+ // dedicated 会话结束后主动标记损坏,确保连接不会跨会话复用。
+ sessionLease.MarkBroken()
+ }
+ unpinSessionConn(sessionConnID)
+ sessionLease.Release()
+ if debugEnabled {
+ logOpenAIWSModeDebug(
+ "ingress_ws_upstream_released account_id=%d conn_id=%s",
+ account.ID,
+ truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
+ )
+ }
+ }
+ defer releaseSessionLease()
+
+ turn := 1
+ turnRetry := 0
+ turnPrevRecoveryTried := false
+ lastTurnFinishedAt := time.Time{}
+ lastTurnResponseID := ""
+ lastTurnPayload := []byte(nil)
+ var lastTurnStrictState *openAIWSIngressPreviousTurnStrictState
+ lastTurnReplayInput := []json.RawMessage(nil)
+ lastTurnReplayInputExists := false
+ currentTurnReplayInput := []json.RawMessage(nil)
+ currentTurnReplayInputExists := false
+ skipBeforeTurn := false
+ resetSessionLease := func(markBroken bool) {
+ if sessionLease == nil {
+ return
+ }
+ if markBroken {
+ sessionLease.MarkBroken()
+ }
+ releaseSessionLease()
+ sessionLease = nil
+ sessionConnID = ""
+ preferredConnID = ""
+ }
+ recoverIngressPrevResponseNotFound := func(relayErr error, turn int, connID string) bool {
+ if !isOpenAIWSIngressPreviousResponseNotFound(relayErr) {
+ return false
+ }
+ if turnPrevRecoveryTried || !s.openAIWSIngressPreviousResponseRecoveryEnabled() {
+ return false
+ }
+ if isStrictAffinityTurn(currentPayload) {
+ // Layer 2:严格亲和链路命中 previous_response_not_found 时,降级为“去掉 previous_response_id 后重放一次”。
+ // 该错误说明续链锚点已失效,继续 strict fail-close 只会直接中断本轮请求。
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_recovery_layer2 account_id=%d turn=%d conn_id=%s store_disabled_conn_mode=%s action=drop_previous_response_id_retry",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(storeDisabledConnMode),
+ )
+ }
+ turnPrevRecoveryTried = true
+ updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
+ if dropErr != nil || !removed {
+ reason := "not_removed"
+ if dropErr != nil {
+ reason = "drop_error"
+ }
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(reason),
+ )
+ return false
+ }
+ updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
+ updatedPayload,
+ currentTurnReplayInput,
+ currentTurnReplayInputExists,
+ )
+ if setInputErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error cause=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen),
+ )
+ return false
+ }
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id retry=1",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ )
+ currentPayload = updatedWithInput
+ currentPayloadBytes = len(updatedWithInput)
+ resetSessionLease(true)
+ skipBeforeTurn = true
+ return true
+ }
+ retryIngressTurn := func(relayErr error, turn int, connID string) bool {
+ if !isOpenAIWSIngressTurnRetryable(relayErr) || turnRetry >= 1 {
+ return false
+ }
+ if isStrictAffinityTurn(currentPayload) {
+ logOpenAIWSModeInfo(
+ "ingress_ws_turn_retry_skip account_id=%d turn=%d conn_id=%s reason=strict_affinity",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ )
+ return false
+ }
+ turnRetry++
+ logOpenAIWSModeInfo(
+ "ingress_ws_turn_retry account_id=%d turn=%d retry=%d reason=%s conn_id=%s",
+ account.ID,
+ turn,
+ turnRetry,
+ truncateOpenAIWSLogValue(openAIWSIngressTurnRetryReason(relayErr), openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ )
+ resetSessionLease(true)
+ skipBeforeTurn = true
+ return true
+ }
+ for {
+ if !skipBeforeTurn && hooks != nil && hooks.BeforeTurn != nil {
+ if err := hooks.BeforeTurn(turn); err != nil {
+ return err
+ }
+ }
+ skipBeforeTurn = false
+ currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")
+ expectedPrev := strings.TrimSpace(lastTurnResponseID)
+ hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists()
+ // store=false + function_call_output 场景必须有续链锚点。
+ // 若客户端未传 previous_response_id,优先回填上一轮响应 ID,避免上游报 call_id 无法关联。
+ if shouldInferIngressFunctionCallOutputPreviousResponseID(
+ storeDisabled,
+ turn,
+ hasFunctionCallOutput,
+ currentPreviousResponseID,
+ expectedPrev,
+ ) {
+ updatedPayload, setPrevErr := setPreviousResponseIDToRawPayload(currentPayload, expectedPrev)
+ if setPrevErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_function_call_output_prev_infer_skip account_id=%d turn=%d conn_id=%s reason=set_previous_response_id_error cause=%s expected_previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(setPrevErr.Error(), openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ )
+ } else {
+ currentPayload = updatedPayload
+ currentPayloadBytes = len(updatedPayload)
+ currentPreviousResponseID = expectedPrev
+ logOpenAIWSModeInfo(
+ "ingress_ws_function_call_output_prev_infer account_id=%d turn=%d conn_id=%s action=set_previous_response_id previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ )
+ }
+ }
+ nextReplayInput, nextReplayInputExists, replayInputErr := buildOpenAIWSReplayInputSequence(
+ lastTurnReplayInput,
+ lastTurnReplayInputExists,
+ currentPayload,
+ currentPreviousResponseID != "",
+ )
+ if replayInputErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_replay_input_skip account_id=%d turn=%d conn_id=%s reason=build_error cause=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(replayInputErr.Error(), openAIWSLogValueMaxLen),
+ )
+ currentTurnReplayInput = nil
+ currentTurnReplayInputExists = false
+ } else {
+ currentTurnReplayInput = nextReplayInput
+ currentTurnReplayInputExists = nextReplayInputExists
+ }
+ if storeDisabled && turn > 1 && currentPreviousResponseID != "" {
+ shouldKeepPreviousResponseID := false
+ strictReason := ""
+ var strictErr error
+ if lastTurnStrictState != nil {
+ shouldKeepPreviousResponseID, strictReason, strictErr = shouldKeepIngressPreviousResponseIDWithStrictState(
+ lastTurnStrictState,
+ currentPayload,
+ lastTurnResponseID,
+ hasFunctionCallOutput,
+ )
+ } else {
+ shouldKeepPreviousResponseID, strictReason, strictErr = shouldKeepIngressPreviousResponseID(
+ lastTurnPayload,
+ currentPayload,
+ lastTurnResponseID,
+ hasFunctionCallOutput,
+ )
+ }
+ if strictErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s cause=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(strictReason),
+ truncateOpenAIWSLogValue(strictErr.Error(), openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ hasFunctionCallOutput,
+ )
+ } else if !shouldKeepPreviousResponseID {
+ updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
+ if dropErr != nil || !removed {
+ dropReason := "not_removed"
+ if dropErr != nil {
+ dropReason = "drop_error"
+ }
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s drop_reason=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(strictReason),
+ normalizeOpenAIWSLogValue(dropReason),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ hasFunctionCallOutput,
+ )
+ } else {
+ updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
+ updatedPayload,
+ currentTurnReplayInput,
+ currentTurnReplayInputExists,
+ )
+ if setInputErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s drop_reason=set_full_input_error previous_response_id=%s expected_previous_response_id=%s cause=%s has_function_call_output=%v",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(strictReason),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen),
+ hasFunctionCallOutput,
+ )
+ } else {
+ currentPayload = updatedWithInput
+ currentPayloadBytes = len(updatedWithInput)
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_full_create reason=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(strictReason),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ hasFunctionCallOutput,
+ )
+ currentPreviousResponseID = ""
+ }
+ }
+ }
+ }
+ forcePreferredConn := isStrictAffinityTurn(currentPayload)
+ if sessionLease == nil {
+ acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn)
+ if acquireErr != nil {
+ return fmt.Errorf("acquire upstream websocket: %w", acquireErr)
+ }
+ sessionLease = acquiredLease
+ sessionConnID = strings.TrimSpace(sessionLease.ConnID())
+ if storeDisabled {
+ pinSessionConn(sessionConnID)
+ } else {
+ unpinSessionConn(sessionConnID)
+ }
+ }
+ shouldPreflightPing := turn > 1 && sessionLease != nil && turnRetry == 0
+ if shouldPreflightPing && openAIWSIngressPreflightPingIdle > 0 && !lastTurnFinishedAt.IsZero() {
+ if time.Since(lastTurnFinishedAt) < openAIWSIngressPreflightPingIdle {
+ shouldPreflightPing = false
+ }
+ }
+ if shouldPreflightPing {
+ if pingErr := sessionLease.PingWithTimeout(openAIWSConnHealthCheckTO); pingErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_upstream_preflight_ping_fail account_id=%d turn=%d conn_id=%s cause=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(pingErr.Error(), openAIWSLogValueMaxLen),
+ )
+ if forcePreferredConn {
+ if !turnPrevRecoveryTried && currentPreviousResponseID != "" {
+ updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
+ if dropErr != nil || !removed {
+ reason := "not_removed"
+ if dropErr != nil {
+ reason = "drop_error"
+ }
+ logOpenAIWSModeInfo(
+ "ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(reason),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ )
+ } else {
+ updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
+ updatedPayload,
+ currentTurnReplayInput,
+ currentTurnReplayInputExists,
+ )
+ if setInputErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error previous_response_id=%s cause=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen),
+ )
+ } else {
+ logOpenAIWSModeInfo(
+ "ingress_ws_preflight_ping_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_retry previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ )
+ turnPrevRecoveryTried = true
+ currentPayload = updatedWithInput
+ currentPayloadBytes = len(updatedWithInput)
+ resetSessionLease(true)
+ skipBeforeTurn = true
+ continue
+ }
+ }
+ }
+ resetSessionLease(true)
+ return NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ "upstream continuation connection is unavailable; please restart the conversation",
+ pingErr,
+ )
+ }
+ resetSessionLease(true)
+
+ acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn)
+ if acquireErr != nil {
+ return fmt.Errorf("acquire upstream websocket after preflight ping fail: %w", acquireErr)
+ }
+ sessionLease = acquiredLease
+ sessionConnID = strings.TrimSpace(sessionLease.ConnID())
+ if storeDisabled {
+ pinSessionConn(sessionConnID)
+ }
+ }
+ }
+ connID := sessionConnID
+ if currentPreviousResponseID != "" {
+ chainedFromLast := expectedPrev != "" && currentPreviousResponseID == expectedPrev
+ currentPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(currentPreviousResponseID)
+ logOpenAIWSModeInfo(
+ "ingress_ws_turn_chain account_id=%d turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v preferred_conn_id=%s header_session_id=%s header_conversation_id=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v store_disabled=%v",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(currentPreviousResponseIDKind),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ chainedFromLast,
+ truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen),
+ openAIWSHeaderValueForLog(baseAcquireReq.Headers, "session_id"),
+ openAIWSHeaderValueForLog(baseAcquireReq.Headers, "conversation_id"),
+ turnState != "",
+ len(turnState),
+ openAIWSPayloadStringFromRaw(currentPayload, "prompt_cache_key") != "",
+ storeDisabled,
+ )
+ }
+
+ result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel)
+ if relayErr != nil {
+ if recoverIngressPrevResponseNotFound(relayErr, turn, connID) {
+ continue
+ }
+ if retryIngressTurn(relayErr, turn, connID) {
+ continue
+ }
+ finalErr := relayErr
+ if unwrapped := errors.Unwrap(relayErr); unwrapped != nil {
+ finalErr = unwrapped
+ }
+ if hooks != nil && hooks.AfterTurn != nil {
+ hooks.AfterTurn(turn, nil, finalErr)
+ }
+ sessionLease.MarkBroken()
+ return finalErr
+ }
+ turnRetry = 0
+ turnPrevRecoveryTried = false
+ lastTurnFinishedAt = time.Now()
+ if hooks != nil && hooks.AfterTurn != nil {
+ hooks.AfterTurn(turn, result, nil)
+ }
+ if result == nil {
+ return errors.New("websocket turn result is nil")
+ }
+ responseID := strings.TrimSpace(result.RequestID)
+ lastTurnResponseID = responseID
+ lastTurnPayload = cloneOpenAIWSPayloadBytes(currentPayload)
+ lastTurnReplayInput = cloneOpenAIWSRawMessages(currentTurnReplayInput)
+ lastTurnReplayInputExists = currentTurnReplayInputExists
+ nextStrictState, strictStateErr := buildOpenAIWSIngressPreviousTurnStrictState(currentPayload)
+ if strictStateErr != nil {
+ lastTurnStrictState = nil
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_strict_state_skip account_id=%d turn=%d conn_id=%s reason=build_error cause=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(strictStateErr.Error(), openAIWSLogValueMaxLen),
+ )
+ } else {
+ lastTurnStrictState = nextStrictState
+ }
+
+ if responseID != "" && stateStore != nil {
+ ttl := s.openAIWSResponseStickyTTL()
+ logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl))
+ stateStore.BindResponseConn(responseID, connID, ttl)
+ }
+ if stateStore != nil && storeDisabled && sessionHash != "" {
+ stateStore.BindSessionConn(groupID, sessionHash, connID, s.openAIWSSessionStickyTTL())
+ }
+ if connID != "" {
+ preferredConnID = connID
+ }
+
+ nextClientMessage, readErr := readClientMessage()
+ if readErr != nil {
+ if isOpenAIWSClientDisconnectError(readErr) {
+ closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr)
+ logOpenAIWSModeInfo(
+ "ingress_ws_client_closed account_id=%d conn_id=%s close_status=%s close_reason=%s",
+ account.ID,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ closeStatus,
+ truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen),
+ )
+ return nil
+ }
+ return fmt.Errorf("read client websocket request: %w", readErr)
+ }
+
+ nextPayload, parseErr := parseClientPayload(nextClientMessage)
+ if parseErr != nil {
+ return parseErr
+ }
+ if nextPayload.promptCacheKey != "" {
+ // ingress 会话在整个客户端 WS 生命周期内复用同一上游连接;
+ // prompt_cache_key 对握手头的更新仅在未来需要重新建连时生效。
+ updatedHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), nextPayload.promptCacheKey)
+ baseAcquireReq.Headers = updatedHeaders
+ }
+ if nextPayload.previousResponseID != "" {
+ expectedPrev := strings.TrimSpace(lastTurnResponseID)
+ chainedFromLast := expectedPrev != "" && nextPayload.previousResponseID == expectedPrev
+ nextPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(nextPayload.previousResponseID)
+ logOpenAIWSModeInfo(
+ "ingress_ws_next_turn_chain account_id=%d turn=%d next_turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v has_prompt_cache_key=%v store_disabled=%v",
+ account.ID,
+ turn,
+ turn+1,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(nextPreviousResponseIDKind),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ chainedFromLast,
+ nextPayload.promptCacheKey != "",
+ storeDisabled,
+ )
+ }
+ if stateStore != nil && nextPayload.previousResponseID != "" {
+ if stickyConnID, ok := stateStore.GetResponseConn(nextPayload.previousResponseID); ok {
+ if sessionConnID != "" && stickyConnID != "" && stickyConnID != sessionConnID {
+ logOpenAIWSModeInfo(
+ "ingress_ws_keep_session_conn account_id=%d turn=%d conn_id=%s sticky_conn_id=%s previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(stickyConnID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen),
+ )
+ } else {
+ preferredConnID = stickyConnID
+ }
+ }
+ }
+ currentPayload = nextPayload.payloadRaw
+ currentOriginalModel = nextPayload.originalModel
+ currentPayloadBytes = nextPayload.payloadBytes
+ storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(currentPayload, account)
+ if !storeDisabled {
+ unpinSessionConn(sessionConnID)
+ }
+ turn++
+ }
+}
+
+func (s *OpenAIGatewayService) isOpenAIWSGeneratePrewarmEnabled() bool {
+ return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled
+}
+
+// performOpenAIWSGeneratePrewarm 在 WSv2 下执行可选的 generate=false 预热。
+// 预热默认关闭,仅在配置开启后生效;失败时按可恢复错误回退到 HTTP。
+func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm(
+ ctx context.Context,
+ lease *openAIWSConnLease,
+ decision OpenAIWSProtocolDecision,
+ payload map[string]any,
+ previousResponseID string,
+ reqBody map[string]any,
+ account *Account,
+ stateStore OpenAIWSStateStore,
+ groupID int64,
+) error {
+ if s == nil {
+ return nil
+ }
+ if lease == nil || account == nil {
+ logOpenAIWSModeInfo("prewarm_skip reason=invalid_state has_lease=%v has_account=%v", lease != nil, account != nil)
+ return nil
+ }
+ connID := strings.TrimSpace(lease.ConnID())
+ if !s.isOpenAIWSGeneratePrewarmEnabled() {
+ return nil
+ }
+ if decision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
+ logOpenAIWSModeInfo(
+ "prewarm_skip account_id=%d conn_id=%s reason=transport_not_v2 transport=%s",
+ account.ID,
+ connID,
+ normalizeOpenAIWSLogValue(string(decision.Transport)),
+ )
+ return nil
+ }
+ if strings.TrimSpace(previousResponseID) != "" {
+ logOpenAIWSModeInfo(
+ "prewarm_skip account_id=%d conn_id=%s reason=has_previous_response_id previous_response_id=%s",
+ account.ID,
+ connID,
+ truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen),
+ )
+ return nil
+ }
+ if lease.IsPrewarmed() {
+ logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=already_prewarmed", account.ID, connID)
+ return nil
+ }
+ if NeedsToolContinuation(reqBody) {
+ logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=tool_continuation", account.ID, connID)
+ return nil
+ }
+ prewarmStart := time.Now()
+ logOpenAIWSModeInfo("prewarm_start account_id=%d conn_id=%s", account.ID, connID)
+
+ prewarmPayload := make(map[string]any, len(payload)+1)
+ for k, v := range payload {
+ prewarmPayload[k] = v
+ }
+ prewarmPayload["generate"] = false
+ prewarmPayloadJSON := payloadAsJSONBytes(prewarmPayload)
+
+ if err := lease.WriteJSONWithContextTimeout(ctx, prewarmPayload, s.openAIWSWriteTimeout()); err != nil {
+ lease.MarkBroken()
+ logOpenAIWSModeInfo(
+ "prewarm_write_fail account_id=%d conn_id=%s cause=%s",
+ account.ID,
+ connID,
+ truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
+ )
+ return wrapOpenAIWSFallback("prewarm_write", err)
+ }
+ logOpenAIWSModeInfo("prewarm_write_sent account_id=%d conn_id=%s payload_bytes=%d", account.ID, connID, len(prewarmPayloadJSON))
+
+ prewarmResponseID := ""
+ prewarmEventCount := 0
+ prewarmTerminalCount := 0
+ for {
+ message, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout())
+ if readErr != nil {
+ lease.MarkBroken()
+ closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr)
+ logOpenAIWSModeInfo(
+ "prewarm_read_fail account_id=%d conn_id=%s close_status=%s close_reason=%s cause=%s events=%d",
+ account.ID,
+ connID,
+ closeStatus,
+ closeReason,
+ truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen),
+ prewarmEventCount,
+ )
+ return wrapOpenAIWSFallback("prewarm_"+classifyOpenAIWSReadFallbackReason(readErr), readErr)
+ }
+
+ eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(message)
+ if eventType == "" {
+ continue
+ }
+ prewarmEventCount++
+ if prewarmResponseID == "" && eventResponseID != "" {
+ prewarmResponseID = eventResponseID
+ }
+ if prewarmEventCount <= openAIWSPrewarmEventLogHead || eventType == "error" || isOpenAIWSTerminalEvent(eventType) {
+ logOpenAIWSModeInfo(
+ "prewarm_event account_id=%d conn_id=%s idx=%d type=%s bytes=%d",
+ account.ID,
+ connID,
+ prewarmEventCount,
+ truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
+ len(message),
+ )
+ }
+
+ if eventType == "error" {
+ errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
+ errMsg := strings.TrimSpace(errMsgRaw)
+ if errMsg == "" {
+ errMsg = "OpenAI websocket prewarm error"
+ }
+ fallbackReason, canFallback := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
+ errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
+ logOpenAIWSModeInfo(
+ "prewarm_error_event account_id=%d conn_id=%s idx=%d fallback_reason=%s can_fallback=%v err_code=%s err_type=%s err_message=%s",
+ account.ID,
+ connID,
+ prewarmEventCount,
+ truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen),
+ canFallback,
+ errCode,
+ errType,
+ errMessage,
+ )
+ lease.MarkBroken()
+ if canFallback {
+ return wrapOpenAIWSFallback("prewarm_"+fallbackReason, errors.New(errMsg))
+ }
+ return wrapOpenAIWSFallback("prewarm_error_event", errors.New(errMsg))
+ }
+
+ if isOpenAIWSTerminalEvent(eventType) {
+ prewarmTerminalCount++
+ break
+ }
+ }
+
+ lease.MarkPrewarmed()
+ if prewarmResponseID != "" && stateStore != nil {
+ ttl := s.openAIWSResponseStickyTTL()
+ logOpenAIWSBindResponseAccountWarn(groupID, account.ID, prewarmResponseID, stateStore.BindResponseAccount(ctx, groupID, prewarmResponseID, account.ID, ttl))
+ stateStore.BindResponseConn(prewarmResponseID, lease.ConnID(), ttl)
+ }
+ logOpenAIWSModeInfo(
+ "prewarm_done account_id=%d conn_id=%s response_id=%s events=%d terminal_events=%d duration_ms=%d",
+ account.ID,
+ connID,
+ truncateOpenAIWSLogValue(prewarmResponseID, openAIWSIDValueMaxLen),
+ prewarmEventCount,
+ prewarmTerminalCount,
+ time.Since(prewarmStart).Milliseconds(),
+ )
+ return nil
+}
+
+func payloadAsJSON(payload map[string]any) string {
+ return string(payloadAsJSONBytes(payload))
+}
+
+func payloadAsJSONBytes(payload map[string]any) []byte {
+ if len(payload) == 0 {
+ return []byte("{}")
+ }
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return []byte("{}")
+ }
+ return body
+}
+
+func isOpenAIWSTerminalEvent(eventType string) bool {
+ switch strings.TrimSpace(eventType) {
+ case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
+ return true
+ default:
+ return false
+ }
+}
+
+func isOpenAIWSTokenEvent(eventType string) bool {
+ eventType = strings.TrimSpace(eventType)
+ if eventType == "" {
+ return false
+ }
+ switch eventType {
+ case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done":
+ return false
+ }
+ if strings.Contains(eventType, ".delta") {
+ return true
+ }
+ if strings.HasPrefix(eventType, "response.output_text") {
+ return true
+ }
+ if strings.HasPrefix(eventType, "response.output") {
+ return true
+ }
+ return eventType == "response.completed" || eventType == "response.done"
+}
+
+func replaceOpenAIWSMessageModel(message []byte, fromModel, toModel string) []byte {
+ if len(message) == 0 {
+ return message
+ }
+ if strings.TrimSpace(fromModel) == "" || strings.TrimSpace(toModel) == "" || fromModel == toModel {
+ return message
+ }
+ if !bytes.Contains(message, []byte(`"model"`)) || !bytes.Contains(message, []byte(fromModel)) {
+ return message
+ }
+ modelValues := gjson.GetManyBytes(message, "model", "response.model")
+ replaceModel := modelValues[0].Exists() && modelValues[0].Str == fromModel
+ replaceResponseModel := modelValues[1].Exists() && modelValues[1].Str == fromModel
+ if !replaceModel && !replaceResponseModel {
+ return message
+ }
+ updated := message
+ if replaceModel {
+ if next, err := sjson.SetBytes(updated, "model", toModel); err == nil {
+ updated = next
+ }
+ }
+ if replaceResponseModel {
+ if next, err := sjson.SetBytes(updated, "response.model", toModel); err == nil {
+ updated = next
+ }
+ }
+ return updated
+}
+
+func populateOpenAIUsageFromResponseJSON(body []byte, usage *OpenAIUsage) {
+ if usage == nil || len(body) == 0 {
+ return
+ }
+ values := gjson.GetManyBytes(
+ body,
+ "usage.input_tokens",
+ "usage.output_tokens",
+ "usage.input_tokens_details.cached_tokens",
+ )
+ usage.InputTokens = int(values[0].Int())
+ usage.OutputTokens = int(values[1].Int())
+ usage.CacheReadInputTokens = int(values[2].Int())
+}
+
+func getOpenAIGroupIDFromContext(c *gin.Context) int64 {
+ if c == nil {
+ return 0
+ }
+ value, exists := c.Get("api_key")
+ if !exists {
+ return 0
+ }
+ apiKey, ok := value.(*APIKey)
+ if !ok || apiKey == nil || apiKey.GroupID == nil {
+ return 0
+ }
+ return *apiKey.GroupID
+}
+
+// SelectAccountByPreviousResponseID 按 previous_response_id 命中账号粘连。
+// 未命中或账号不可用时返回 (nil, nil),由调用方继续走常规调度。
+func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
+ ctx context.Context,
+ groupID *int64,
+ previousResponseID string,
+ requestedModel string,
+ excludedIDs map[int64]struct{},
+) (*AccountSelectionResult, error) {
+ if s == nil {
+ return nil, nil
+ }
+ responseID := strings.TrimSpace(previousResponseID)
+ if responseID == "" {
+ return nil, nil
+ }
+ store := s.getOpenAIWSStateStore()
+ if store == nil {
+ return nil, nil
+ }
+
+ accountID, err := store.GetResponseAccount(ctx, derefGroupID(groupID), responseID)
+ if err != nil || accountID <= 0 {
+ return nil, nil
+ }
+ if excludedIDs != nil {
+ if _, excluded := excludedIDs[accountID]; excluded {
+ return nil, nil
+ }
+ }
+
+ account, err := s.getSchedulableAccount(ctx, accountID)
+ if err != nil || account == nil {
+ _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
+ return nil, nil
+ }
+ // 非 WSv2 场景(如 force_http/全局关闭)不应使用 previous_response_id 粘连,
+ // 以保持“回滚到 HTTP”后的历史行为一致性。
+ if s.getOpenAIWSProtocolResolver().Resolve(account).Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
+ return nil, nil
+ }
+ if shouldClearStickySession(account, requestedModel) || !account.IsOpenAI() {
+ _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID)
+ return nil, nil
+ }
+ if requestedModel != "" && !account.IsModelSupported(requestedModel) {
+ return nil, nil
+ }
+
+ result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
+ if acquireErr == nil && result.Acquired {
+ logOpenAIWSBindResponseAccountWarn(
+ derefGroupID(groupID),
+ accountID,
+ responseID,
+ store.BindResponseAccount(ctx, derefGroupID(groupID), responseID, accountID, s.openAIWSResponseStickyTTL()),
+ )
+ return &AccountSelectionResult{
+ Account: account,
+ Acquired: true,
+ ReleaseFunc: result.ReleaseFunc,
+ }, nil
+ }
+
+ cfg := s.schedulingConfig()
+ if s.concurrencyService != nil {
+ return &AccountSelectionResult{
+ Account: account,
+ WaitPlan: &AccountWaitPlan{
+ AccountID: accountID,
+ MaxConcurrency: account.Concurrency,
+ Timeout: cfg.StickySessionWaitTimeout,
+ MaxWaiting: cfg.StickySessionMaxWaiting,
+ },
+ }, nil
+ }
+ return nil, nil
+}
+
+func classifyOpenAIWSAcquireError(err error) string {
+ if err == nil {
+ return "acquire_conn"
+ }
+ var dialErr *openAIWSDialError
+ if errors.As(err, &dialErr) {
+ switch dialErr.StatusCode {
+ case 426:
+ return "upgrade_required"
+ case 401, 403:
+ return "auth_failed"
+ case 429:
+ return "upstream_rate_limited"
+ }
+ if dialErr.StatusCode >= 500 {
+ return "upstream_5xx"
+ }
+ return "dial_failed"
+ }
+ if errors.Is(err, errOpenAIWSConnQueueFull) {
+ return "conn_queue_full"
+ }
+ if errors.Is(err, errOpenAIWSPreferredConnUnavailable) {
+ return "preferred_conn_unavailable"
+ }
+ if errors.Is(err, context.DeadlineExceeded) {
+ return "acquire_timeout"
+ }
+ return "acquire_conn"
+}
+
+func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) {
+ code := strings.ToLower(strings.TrimSpace(codeRaw))
+ errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
+ msg := strings.ToLower(strings.TrimSpace(msgRaw))
+
+ switch code {
+ case "upgrade_required":
+ return "upgrade_required", true
+ case "websocket_not_supported", "websocket_unsupported":
+ return "ws_unsupported", true
+ case "websocket_connection_limit_reached":
+ return "ws_connection_limit_reached", true
+ case "previous_response_not_found":
+ return "previous_response_not_found", true
+ }
+ if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") {
+ return "upgrade_required", true
+ }
+ if strings.Contains(errType, "upgrade") {
+ return "upgrade_required", true
+ }
+ if strings.Contains(msg, "websocket") && strings.Contains(msg, "unsupported") {
+ return "ws_unsupported", true
+ }
+ if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") {
+ return "ws_connection_limit_reached", true
+ }
+ if strings.Contains(msg, "previous_response_not_found") ||
+ (strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) {
+ return "previous_response_not_found", true
+ }
+ if strings.Contains(errType, "server_error") || strings.Contains(code, "server_error") {
+ return "upstream_error_event", true
+ }
+ return "event_error", false
+}
+
+func classifyOpenAIWSErrorEvent(message []byte) (string, bool) {
+ if len(message) == 0 {
+ return "event_error", false
+ }
+ return classifyOpenAIWSErrorEventFromRaw(parseOpenAIWSErrorEventFields(message))
+}
+
+func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int {
+ code := strings.ToLower(strings.TrimSpace(codeRaw))
+ errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
+ switch {
+ case strings.Contains(errType, "invalid_request"),
+ strings.Contains(code, "invalid_request"),
+ strings.Contains(code, "bad_request"),
+ code == "previous_response_not_found":
+ return http.StatusBadRequest
+ case strings.Contains(errType, "authentication"),
+ strings.Contains(code, "invalid_api_key"),
+ strings.Contains(code, "unauthorized"):
+ return http.StatusUnauthorized
+ case strings.Contains(errType, "permission"),
+ strings.Contains(code, "forbidden"):
+ return http.StatusForbidden
+ case strings.Contains(errType, "rate_limit"),
+ strings.Contains(code, "rate_limit"),
+ strings.Contains(code, "insufficient_quota"):
+ return http.StatusTooManyRequests
+ default:
+ return http.StatusBadGateway
+ }
+}
+
+func openAIWSErrorHTTPStatus(message []byte) int {
+ if len(message) == 0 {
+ return http.StatusBadGateway
+ }
+ codeRaw, errTypeRaw, _ := parseOpenAIWSErrorEventFields(message)
+ return openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw)
+}
+
+func (s *OpenAIGatewayService) openAIWSFallbackCooldown() time.Duration {
+ if s == nil || s.cfg == nil {
+ return 30 * time.Second
+ }
+ seconds := s.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds
+ if seconds <= 0 {
+ return 0
+ }
+ return time.Duration(seconds) * time.Second
+}
+
+func (s *OpenAIGatewayService) isOpenAIWSFallbackCooling(accountID int64) bool {
+ if s == nil || accountID <= 0 {
+ return false
+ }
+ cooldown := s.openAIWSFallbackCooldown()
+ if cooldown <= 0 {
+ return false
+ }
+ rawUntil, ok := s.openaiWSFallbackUntil.Load(accountID)
+ if !ok || rawUntil == nil {
+ return false
+ }
+ until, ok := rawUntil.(time.Time)
+ if !ok || until.IsZero() {
+ s.openaiWSFallbackUntil.Delete(accountID)
+ return false
+ }
+ if time.Now().Before(until) {
+ return true
+ }
+ s.openaiWSFallbackUntil.Delete(accountID)
+ return false
+}
+
+func (s *OpenAIGatewayService) markOpenAIWSFallbackCooling(accountID int64, _ string) {
+ if s == nil || accountID <= 0 {
+ return
+ }
+ cooldown := s.openAIWSFallbackCooldown()
+ if cooldown <= 0 {
+ return
+ }
+ s.openaiWSFallbackUntil.Store(accountID, time.Now().Add(cooldown))
+}
+
+func (s *OpenAIGatewayService) clearOpenAIWSFallbackCooling(accountID int64) {
+ if s == nil || accountID <= 0 {
+ return
+ }
+ s.openaiWSFallbackUntil.Delete(accountID)
+}
diff --git a/backend/internal/service/openai_ws_forwarder_benchmark_test.go b/backend/internal/service/openai_ws_forwarder_benchmark_test.go
new file mode 100644
index 00000000..bd03ab5a
--- /dev/null
+++ b/backend/internal/service/openai_ws_forwarder_benchmark_test.go
@@ -0,0 +1,127 @@
+package service
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+)
+
+var (
+ benchmarkOpenAIWSPayloadJSONSink string
+ benchmarkOpenAIWSStringSink string
+ benchmarkOpenAIWSBoolSink bool
+ benchmarkOpenAIWSBytesSink []byte
+)
+
+func BenchmarkOpenAIWSForwarderHotPath(b *testing.B) {
+ cfg := &config.Config{}
+ svc := &OpenAIGatewayService{cfg: cfg}
+ account := &Account{ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
+ reqBody := benchmarkOpenAIWSHotPathRequest()
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ payload := svc.buildOpenAIWSCreatePayload(reqBody, account)
+ _, _ = applyOpenAIWSRetryPayloadStrategy(payload, 2)
+ setOpenAIWSTurnMetadata(payload, `{"trace":"bench","turn":"1"}`)
+
+ benchmarkOpenAIWSStringSink = openAIWSPayloadString(payload, "previous_response_id")
+ benchmarkOpenAIWSBoolSink = payload["tools"] != nil
+ benchmarkOpenAIWSStringSink = summarizeOpenAIWSPayloadKeySizes(payload, openAIWSPayloadKeySizeTopN)
+ benchmarkOpenAIWSStringSink = summarizeOpenAIWSInput(payload["input"])
+ benchmarkOpenAIWSPayloadJSONSink = payloadAsJSON(payload)
+ }
+}
+
+func benchmarkOpenAIWSHotPathRequest() map[string]any {
+ tools := make([]map[string]any, 0, 24)
+ for i := 0; i < 24; i++ {
+ tools = append(tools, map[string]any{
+ "type": "function",
+ "name": fmt.Sprintf("tool_%02d", i),
+ "description": "benchmark tool schema",
+ "parameters": map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "query": map[string]any{"type": "string"},
+ "limit": map[string]any{"type": "number"},
+ },
+ "required": []string{"query"},
+ },
+ })
+ }
+
+ input := make([]map[string]any, 0, 16)
+ for i := 0; i < 16; i++ {
+ input = append(input, map[string]any{
+ "role": "user",
+ "type": "message",
+ "content": fmt.Sprintf("benchmark message %d", i),
+ })
+ }
+
+ return map[string]any{
+ "type": "response.create",
+ "model": "gpt-5.3-codex",
+ "input": input,
+ "tools": tools,
+ "parallel_tool_calls": true,
+ "previous_response_id": "resp_benchmark_prev",
+ "prompt_cache_key": "bench-cache-key",
+ "reasoning": map[string]any{"effort": "medium"},
+ "instructions": "benchmark instructions",
+ "store": false,
+ }
+}
+
+func BenchmarkOpenAIWSEventEnvelopeParse(b *testing.B) {
+ event := []byte(`{"type":"response.completed","response":{"id":"resp_bench_1","model":"gpt-5.1","usage":{"input_tokens":12,"output_tokens":8}}}`)
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ eventType, responseID, response := parseOpenAIWSEventEnvelope(event)
+ benchmarkOpenAIWSStringSink = eventType
+ benchmarkOpenAIWSStringSink = responseID
+ benchmarkOpenAIWSBoolSink = response.Exists()
+ }
+}
+
+func BenchmarkOpenAIWSErrorEventFieldReuse(b *testing.B) {
+ event := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(event)
+ benchmarkOpenAIWSStringSink, benchmarkOpenAIWSBoolSink = classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, errMsgRaw)
+ code, errType, errMsg := summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMsgRaw)
+ benchmarkOpenAIWSStringSink = code
+ benchmarkOpenAIWSStringSink = errType
+ benchmarkOpenAIWSStringSink = errMsg
+ benchmarkOpenAIWSBoolSink = openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) > 0
+ }
+}
+
+func BenchmarkReplaceOpenAIWSMessageModel_NoMatchFastPath(b *testing.B) {
+ event := []byte(`{"type":"response.output_text.delta","delta":"hello world"}`)
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ benchmarkOpenAIWSBytesSink = replaceOpenAIWSMessageModel(event, "gpt-5.1", "custom-model")
+ }
+}
+
+func BenchmarkReplaceOpenAIWSMessageModel_DualReplace(b *testing.B) {
+ event := []byte(`{"type":"response.completed","model":"gpt-5.1","response":{"id":"resp_1","model":"gpt-5.1"}}`)
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ benchmarkOpenAIWSBytesSink = replaceOpenAIWSMessageModel(event, "gpt-5.1", "custom-model")
+ }
+}
diff --git a/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go
new file mode 100644
index 00000000..76167603
--- /dev/null
+++ b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go
@@ -0,0 +1,73 @@
+package service
+
+import (
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestParseOpenAIWSEventEnvelope(t *testing.T) {
+ eventType, responseID, response := parseOpenAIWSEventEnvelope([]byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.1"}}`))
+ require.Equal(t, "response.completed", eventType)
+ require.Equal(t, "resp_1", responseID)
+ require.True(t, response.Exists())
+ require.Equal(t, `{"id":"resp_1","model":"gpt-5.1"}`, response.Raw)
+
+ eventType, responseID, response = parseOpenAIWSEventEnvelope([]byte(`{"type":"response.delta","id":"evt_1"}`))
+ require.Equal(t, "response.delta", eventType)
+ require.Equal(t, "evt_1", responseID)
+ require.False(t, response.Exists())
+}
+
+func TestParseOpenAIWSResponseUsageFromCompletedEvent(t *testing.T) {
+ usage := &OpenAIUsage{}
+ parseOpenAIWSResponseUsageFromCompletedEvent(
+ []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":11,"output_tokens":7,"input_tokens_details":{"cached_tokens":3}}}}`),
+ usage,
+ )
+ require.Equal(t, 11, usage.InputTokens)
+ require.Equal(t, 7, usage.OutputTokens)
+ require.Equal(t, 3, usage.CacheReadInputTokens)
+}
+
+func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) {
+ message := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)
+ codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
+
+ wrappedReason, wrappedRecoverable := classifyOpenAIWSErrorEvent(message)
+ rawReason, rawRecoverable := classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, errMsgRaw)
+ require.Equal(t, wrappedReason, rawReason)
+ require.Equal(t, wrappedRecoverable, rawRecoverable)
+
+ wrappedStatus := openAIWSErrorHTTPStatus(message)
+ rawStatus := openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw)
+ require.Equal(t, wrappedStatus, rawStatus)
+ require.Equal(t, http.StatusBadRequest, rawStatus)
+
+ wrappedCode, wrappedType, wrappedMsg := summarizeOpenAIWSErrorEventFields(message)
+ rawCode, rawType, rawMsg := summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMsgRaw)
+ require.Equal(t, wrappedCode, rawCode)
+ require.Equal(t, wrappedType, rawType)
+ require.Equal(t, wrappedMsg, rawMsg)
+}
+
+func TestOpenAIWSMessageLikelyContainsToolCalls(t *testing.T) {
+ require.False(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_text.delta","delta":"hello"}`)))
+ require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"tool_calls":[{"id":"tc1"}]}}`)))
+ require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"type":"function_call"}}`)))
+}
+
+func TestReplaceOpenAIWSMessageModel_OptimizedStillCorrect(t *testing.T) {
+ noModel := []byte(`{"type":"response.output_text.delta","delta":"hello"}`)
+ require.Equal(t, string(noModel), string(replaceOpenAIWSMessageModel(noModel, "gpt-5.1", "custom-model")))
+
+ rootOnly := []byte(`{"type":"response.created","model":"gpt-5.1"}`)
+ require.Equal(t, `{"type":"response.created","model":"custom-model"}`, string(replaceOpenAIWSMessageModel(rootOnly, "gpt-5.1", "custom-model")))
+
+ responseOnly := []byte(`{"type":"response.completed","response":{"model":"gpt-5.1"}}`)
+ require.Equal(t, `{"type":"response.completed","response":{"model":"custom-model"}}`, string(replaceOpenAIWSMessageModel(responseOnly, "gpt-5.1", "custom-model")))
+
+ both := []byte(`{"model":"gpt-5.1","response":{"model":"gpt-5.1"}}`)
+ require.Equal(t, `{"model":"custom-model","response":{"model":"custom-model"}}`, string(replaceOpenAIWSMessageModel(both, "gpt-5.1", "custom-model")))
+}
diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go
new file mode 100644
index 00000000..5a3c12c3
--- /dev/null
+++ b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go
@@ -0,0 +1,2483 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ coderws "github.com/coder/websocket"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossTurns(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ captureConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_ingress_turn_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ []byte(`{"type":"response.completed","response":{"id":"resp_ingress_turn_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ captureDialer := &openAIWSCaptureDialer{conn: captureConn}
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(captureDialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 114,
+ Name: "openai-ingress-session-lease",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ turnWSModeCh := make(chan bool, 2)
+ hooks := &OpenAIWSIngressHooks{
+ AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) {
+ if turnErr == nil && result != nil {
+ turnWSModeCh <- result.OpenAIWSMode
+ }
+ },
+ }
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeMessage := func(payload string) {
+ writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
+ }
+ readMessage := func() []byte {
+ readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ msgType, message, readErr := clientConn.Read(readCtx)
+ require.NoError(t, readErr)
+ require.Equal(t, coderws.MessageText, msgType)
+ return message
+ }
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`)
+ firstTurnEvent := readMessage()
+ require.Equal(t, "response.completed", gjson.GetBytes(firstTurnEvent, "type").String())
+ require.Equal(t, "resp_ingress_turn_1", gjson.GetBytes(firstTurnEvent, "response.id").String())
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_ingress_turn_1"}`)
+ secondTurnEvent := readMessage()
+ require.Equal(t, "response.completed", gjson.GetBytes(secondTurnEvent, "type").String())
+ require.Equal(t, "resp_ingress_turn_2", gjson.GetBytes(secondTurnEvent, "response.id").String())
+ require.True(t, <-turnWSModeCh, "首轮 turn 应标记为 WS 模式")
+ require.True(t, <-turnWSModeCh, "第二轮 turn 应标记为 WS 模式")
+
+ require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+
+ select {
+ case serverErr := <-serverErrCh:
+ require.NoError(t, serverErr)
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 结束超时")
+ }
+
+ metrics := svc.SnapshotOpenAIWSPoolMetrics()
+ require.Equal(t, int64(1), metrics.AcquireTotal, "同一 ingress 会话多 turn 应只获取一次上游 lease")
+ require.Equal(t, 1, captureDialer.DialCount(), "同一 ingress 会话应保持同一上游连接")
+ require.Len(t, captureConn.writes, 2, "应向同一上游连接发送两轮 response.create")
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoesNotReuseConnAcrossSessions(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
+ cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ upstreamConn1 := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_dedicated_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ upstreamConn2 := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_dedicated_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{upstreamConn1, upstreamConn2},
+ }
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(dialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 441,
+ Name: "openai-ingress-dedicated",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
+ },
+ }
+
+ serverErrCh := make(chan error, 2)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ }))
+ defer wsServer.Close()
+
+ runSingleTurnSession := func(expectedResponseID string) {
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+ err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
+ cancelWrite()
+ require.NoError(t, err)
+
+ readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
+ msgType, event, readErr := clientConn.Read(readCtx)
+ cancelRead()
+ require.NoError(t, readErr)
+ require.Equal(t, coderws.MessageText, msgType)
+ require.Equal(t, expectedResponseID, gjson.GetBytes(event, "response.id").String())
+
+ require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+
+ select {
+ case serverErr := <-serverErrCh:
+ require.NoError(t, serverErr)
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 结束超时")
+ }
+ }
+
+ runSingleTurnSession("resp_dedicated_1")
+ runSingleTurnSession("resp_dedicated_2")
+
+ require.Equal(t, 2, dialer.DialCount(), "dedicated 模式下跨客户端会话不应复用上游连接")
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
+ cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: newOpenAIWSConnPool(cfg),
+ }
+
+ account := &Account{
+ ID: 442,
+ Name: "openai-ingress-off",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeOff,
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+ err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
+ cancelWrite()
+ require.NoError(t, err)
+
+ select {
+ case serverErr := <-serverErrCh:
+ var closeErr *OpenAIWSClientCloseError
+ require.ErrorAs(t, serverErr, &closeErr)
+ require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
+ require.Equal(t, "websocket mode is disabled for this account", closeErr.Reason())
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 结束超时")
+ }
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPrevResponseStrictDropToFullCreate(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ captureConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_preflight_rewrite_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ []byte(`{"type":"response.completed","response":{"id":"resp_preflight_rewrite_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ captureDialer := &openAIWSCaptureDialer{conn: captureConn}
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(captureDialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 140,
+ Name: "openai-ingress-prev-preflight-rewrite",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeMessage := func(payload string) {
+ writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
+ }
+ readMessage := func() []byte {
+ readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ msgType, message, readErr := clientConn.Read(readCtx)
+ require.NoError(t, readErr)
+ require.Equal(t, coderws.MessageText, msgType)
+ return message
+ }
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
+ firstTurn := readMessage()
+ require.Equal(t, "resp_preflight_rewrite_1", gjson.GetBytes(firstTurn, "response.id").String())
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_stale_external","input":[{"type":"input_text","text":"world"}]}`)
+ secondTurn := readMessage()
+ require.Equal(t, "resp_preflight_rewrite_2", gjson.GetBytes(secondTurn, "response.id").String())
+
+ require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+ select {
+ case serverErr := <-serverErrCh:
+ require.NoError(t, serverErr)
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 结束超时")
+ }
+
+ require.Equal(t, 1, captureDialer.DialCount(), "严格增量不成立时应在同一连接内降级为 full create")
+ require.Len(t, captureConn.writes, 2)
+ secondWrite := requestToJSONString(captureConn.writes[1])
+ require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "严格增量不成立时应移除 previous_response_id,改为 full create")
+ require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()), "严格降级为 full create 时应重放完整 input 上下文")
+ require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String())
+ require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String())
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPrevResponseStrictDropBeforePreflightPingFailReconnects(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
+ openAIWSIngressPreflightPingIdle = 0
+ defer func() {
+ openAIWSIngressPreflightPingIdle = prevPreflightPingIdle
+ }()
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ firstConn := &openAIWSPreflightFailConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_drop_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ secondConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_drop_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{firstConn, secondConn},
+ }
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(dialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 142,
+ Name: "openai-ingress-prev-strict-drop-before-ping",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeMessage := func(payload string) {
+ writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
+ }
+ readMessage := func() []byte {
+ readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ msgType, message, readErr := clientConn.Read(readCtx)
+ require.NoError(t, readErr)
+ require.Equal(t, coderws.MessageText, msgType)
+ return message
+ }
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
+ firstTurn := readMessage()
+ require.Equal(t, "resp_turn_ping_drop_1", gjson.GetBytes(firstTurn, "response.id").String())
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_stale_external","input":[{"type":"input_text","text":"world"}]}`)
+ secondTurn := readMessage()
+ require.Equal(t, "resp_turn_ping_drop_2", gjson.GetBytes(secondTurn, "response.id").String())
+
+ require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+ select {
+ case serverErr := <-serverErrCh:
+ require.NoError(t, serverErr)
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 严格降级后预检换连超时")
+ }
+
+ require.Equal(t, 2, dialer.DialCount(), "严格降级为 full create 后,预检 ping 失败应允许换连")
+ require.Equal(t, 1, firstConn.WriteCount(), "首连接在预检失败后不应继续发送第二轮")
+ require.GreaterOrEqual(t, firstConn.PingCount(), 1, "第二轮前应执行 preflight ping")
+ secondConn.mu.Lock()
+ secondWrites := append([]map[string]any(nil), secondConn.writes...)
+ secondConn.mu.Unlock()
+ require.Len(t, secondWrites, 1)
+ secondWrite := requestToJSONString(secondWrites[0])
+ require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "严格降级后重试应移除 previous_response_id")
+ require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()))
+ require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String())
+ require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String())
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreEnabledSkipsStrictPrevResponseEval(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ captureConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_store_enabled_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ []byte(`{"type":"response.completed","response":{"id":"resp_store_enabled_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ captureDialer := &openAIWSCaptureDialer{conn: captureConn}
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(captureDialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 143,
+ Name: "openai-ingress-store-enabled-skip-strict",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeMessage := func(payload string) {
+ writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
+ }
+ readMessage := func() []byte {
+ readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ msgType, message, readErr := clientConn.Read(readCtx)
+ require.NoError(t, readErr)
+ require.Equal(t, coderws.MessageText, msgType)
+ return message
+ }
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":true}`)
+ firstTurn := readMessage()
+ require.Equal(t, "resp_store_enabled_1", gjson.GetBytes(firstTurn, "response.id").String())
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":true,"previous_response_id":"resp_stale_external"}`)
+ secondTurn := readMessage()
+ require.Equal(t, "resp_store_enabled_2", gjson.GetBytes(secondTurn, "response.id").String())
+
+ require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+ select {
+ case serverErr := <-serverErrCh:
+ require.NoError(t, serverErr)
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 store=true 场景 websocket 结束超时")
+ }
+
+ require.Equal(t, 1, captureDialer.DialCount())
+ require.Len(t, captureConn.writes, 2)
+ require.Equal(t, "resp_stale_external", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "store=true 场景不应触发 store-disabled strict 规则")
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPrevResponsePreflightSkipForFunctionCallOutput(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ captureConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_preflight_skip_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ []byte(`{"type":"response.completed","response":{"id":"resp_preflight_skip_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ captureDialer := &openAIWSCaptureDialer{conn: captureConn}
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(captureDialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 141,
+ Name: "openai-ingress-prev-preflight-skip-fco",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeMessage := func(payload string) {
+ writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
+ }
+ readMessage := func() []byte {
+ readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ msgType, message, readErr := clientConn.Read(readCtx)
+ require.NoError(t, readErr)
+ require.Equal(t, coderws.MessageText, msgType)
+ return message
+ }
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false}`)
+ firstTurn := readMessage()
+ require.Equal(t, "resp_preflight_skip_1", gjson.GetBytes(firstTurn, "response.id").String())
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_stale_external","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+ secondTurn := readMessage()
+ require.Equal(t, "resp_preflight_skip_2", gjson.GetBytes(secondTurn, "response.id").String())
+
+ require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+ select {
+ case serverErr := <-serverErrCh:
+ require.NoError(t, serverErr)
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 结束超时")
+ }
+
+ require.Equal(t, 1, captureDialer.DialCount())
+ require.Len(t, captureConn.writes, 2)
+ require.Equal(t, "resp_stale_external", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "function_call_output 场景不应预改写 previous_response_id")
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputAutoAttachPreviousResponseID(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ captureConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ captureDialer := &openAIWSCaptureDialer{conn: captureConn}
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(captureDialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 143,
+ Name: "openai-ingress-fco-auto-prev",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeMessage := func(payload string) {
+ writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
+ }
+ readMessage := func() []byte {
+ readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ msgType, message, readErr := clientConn.Read(readCtx)
+ require.NoError(t, readErr)
+ require.Equal(t, coderws.MessageText, msgType)
+ return message
+ }
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
+ firstTurn := readMessage()
+ require.Equal(t, "resp_auto_prev_1", gjson.GetBytes(firstTurn, "response.id").String())
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call_output","call_id":"call_auto_1","output":"ok"}]}`)
+ secondTurn := readMessage()
+ require.Equal(t, "resp_auto_prev_2", gjson.GetBytes(secondTurn, "response.id").String())
+
+ require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+ select {
+ case serverErr := <-serverErrCh:
+ require.NoError(t, serverErr)
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 结束超时")
+ }
+
+ require.Equal(t, 1, captureDialer.DialCount())
+ require.Len(t, captureConn.writes, 2)
+ require.Equal(t, "resp_auto_prev_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "function_call_output 缺失 previous_response_id 时应回填上一轮响应 ID")
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenLastResponseIDMissing(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ captureConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_skip_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ captureDialer := &openAIWSCaptureDialer{conn: captureConn}
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(captureDialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 144,
+ Name: "openai-ingress-fco-auto-prev-skip",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeMessage := func(payload string) {
+ writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
+ }
+ readMessage := func() []byte {
+ readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ msgType, message, readErr := clientConn.Read(readCtx)
+ require.NoError(t, readErr)
+ require.Equal(t, coderws.MessageText, msgType)
+ return message
+ }
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
+ firstTurn := readMessage()
+ require.Equal(t, "response.completed", gjson.GetBytes(firstTurn, "type").String())
+ require.Empty(t, gjson.GetBytes(firstTurn, "response.id").String(), "首轮响应不返回 response.id,模拟无法推导续链锚点")
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call_output","call_id":"call_auto_skip_1","output":"ok"}]}`)
+ secondTurn := readMessage()
+ require.Equal(t, "resp_auto_prev_skip_2", gjson.GetBytes(secondTurn, "response.id").String())
+
+ require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+ select {
+ case serverErr := <-serverErrCh:
+ require.NoError(t, serverErr)
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 结束超时")
+ }
+
+ require.Equal(t, 1, captureDialer.DialCount())
+ require.Len(t, captureConn.writes, 2)
+ require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "上一轮缺失 response.id 时不应自动补齐 previous_response_id")
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreflightPingFailReconnectsBeforeTurn(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
+ openAIWSIngressPreflightPingIdle = 0
+ defer func() {
+ openAIWSIngressPreflightPingIdle = prevPreflightPingIdle
+ }()
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ firstConn := &openAIWSPreflightFailConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ secondConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{firstConn, secondConn},
+ }
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(dialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 116,
+ Name: "openai-ingress-preflight-ping",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeMessage := func(payload string) {
+ writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
+ }
+ readMessage := func() []byte {
+ readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ msgType, message, readErr := clientConn.Read(readCtx)
+ require.NoError(t, readErr)
+ require.Equal(t, coderws.MessageText, msgType)
+ return message
+ }
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`)
+ firstTurn := readMessage()
+ require.Equal(t, "resp_turn_ping_1", gjson.GetBytes(firstTurn, "response.id").String())
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_ping_1"}`)
+ secondTurn := readMessage()
+ require.Equal(t, "resp_turn_ping_2", gjson.GetBytes(secondTurn, "response.id").String())
+
+ require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+ select {
+ case serverErr := <-serverErrCh:
+ require.NoError(t, serverErr)
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 结束超时")
+ }
+ require.Equal(t, 2, dialer.DialCount(), "第二轮 turn 前 ping 失败应触发换连")
+ require.Equal(t, 1, firstConn.WriteCount(), "preflight ping 失败后不应继续向旧连接发送第二轮 turn")
+ require.GreaterOrEqual(t, firstConn.PingCount(), 1, "第二轮前应对旧连接执行 preflight ping")
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStrictAffinityPreflightPingFailAutoRecoveryReconnects(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
+ openAIWSIngressPreflightPingIdle = 0
+ defer func() {
+ openAIWSIngressPreflightPingIdle = prevPreflightPingIdle
+ }()
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ firstConn := &openAIWSPreflightFailConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_strict_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ secondConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_strict_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{firstConn, secondConn},
+ }
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(dialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 121,
+ Name: "openai-ingress-preflight-ping-strict-affinity",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeMessage := func(payload string) {
+ writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
+ }
+ readMessage := func() []byte {
+ readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ msgType, message, readErr := clientConn.Read(readCtx)
+ require.NoError(t, readErr)
+ require.Equal(t, coderws.MessageText, msgType)
+ return message
+ }
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
+ firstTurn := readMessage()
+ require.Equal(t, "resp_turn_ping_strict_1", gjson.GetBytes(firstTurn, "response.id").String())
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_turn_ping_strict_1","input":[{"type":"input_text","text":"world"}]}`)
+ secondTurn := readMessage()
+ require.Equal(t, "resp_turn_ping_strict_2", gjson.GetBytes(secondTurn, "response.id").String())
+
+ require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+ select {
+ case serverErr := <-serverErrCh:
+ require.NoError(t, serverErr)
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 严格亲和自动恢复后结束超时")
+ }
+
+ require.Equal(t, 2, dialer.DialCount(), "严格亲和 preflight ping 失败后应自动降级并换连重放")
+ require.Equal(t, 1, firstConn.WriteCount(), "preflight ping 失败后不应继续在旧连接写第二轮")
+ require.GreaterOrEqual(t, firstConn.PingCount(), 1, "第二轮前应执行 preflight ping")
+ secondConn.mu.Lock()
+ secondWrites := append([]map[string]any(nil), secondConn.writes...)
+ secondConn.mu.Unlock()
+ require.Len(t, secondWrites, 1)
+ secondWrite := requestToJSONString(secondWrites[0])
+ require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "自动恢复重放应移除 previous_response_id")
+ require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()), "自动恢复重放应使用完整 input 上下文")
+ require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String())
+ require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String())
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_WriteFailBeforeDownstreamRetriesOnce(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ firstConn := &openAIWSWriteFailAfterFirstTurnConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_turn_write_retry_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ secondConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_turn_write_retry_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{firstConn, secondConn},
+ }
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(dialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 117,
+ Name: "openai-ingress-write-retry",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+ var hooksMu sync.Mutex
+ beforeTurnCalls := make(map[int]int)
+ afterTurnCalls := make(map[int]int)
+ hooks := &OpenAIWSIngressHooks{
+ BeforeTurn: func(turn int) error {
+ hooksMu.Lock()
+ beforeTurnCalls[turn]++
+ hooksMu.Unlock()
+ return nil
+ },
+ AfterTurn: func(turn int, _ *OpenAIForwardResult, _ error) {
+ hooksMu.Lock()
+ afterTurnCalls[turn]++
+ hooksMu.Unlock()
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeMessage := func(payload string) {
+ writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
+ }
+ readMessage := func() []byte {
+ readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ msgType, message, readErr := clientConn.Read(readCtx)
+ require.NoError(t, readErr)
+ require.Equal(t, coderws.MessageText, msgType)
+ return message
+ }
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`)
+ firstTurn := readMessage()
+ require.Equal(t, "resp_turn_write_retry_1", gjson.GetBytes(firstTurn, "response.id").String())
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_write_retry_1"}`)
+ secondTurn := readMessage()
+ require.Equal(t, "resp_turn_write_retry_2", gjson.GetBytes(secondTurn, "response.id").String())
+
+ require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+ select {
+ case serverErr := <-serverErrCh:
+ require.NoError(t, serverErr)
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 结束超时")
+ }
+ require.Equal(t, 2, dialer.DialCount(), "第二轮 turn 上游写失败且未写下游时应自动重试并换连")
+ hooksMu.Lock()
+ beforeTurn1 := beforeTurnCalls[1]
+ beforeTurn2 := beforeTurnCalls[2]
+ afterTurn1 := afterTurnCalls[1]
+ afterTurn2 := afterTurnCalls[2]
+ hooksMu.Unlock()
+ require.Equal(t, 1, beforeTurn1, "首轮 turn BeforeTurn 应执行一次")
+ require.Equal(t, 1, beforeTurn2, "同一 turn 重试不应重复触发 BeforeTurn")
+ require.Equal(t, 1, afterTurn1, "首轮 turn AfterTurn 应执行一次")
+ require.Equal(t, 1, afterTurn2, "第二轮 turn AfterTurn 应执行一次")
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreviousResponseNotFoundRecoversByDroppingPrevID(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ firstConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_recover_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":""}}`),
+ },
+ }
+ secondConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_recover_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{firstConn, secondConn},
+ }
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(dialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 118,
+ Name: "openai-ingress-prev-recovery",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeMessage := func(payload string) {
+ writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
+ }
+ readMessage := func() []byte {
+ readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ msgType, message, readErr := clientConn.Read(readCtx)
+ require.NoError(t, readErr)
+ require.Equal(t, coderws.MessageText, msgType)
+ return message
+ }
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_seed_anchor"}`)
+ firstTurn := readMessage()
+ require.Equal(t, "resp_turn_prev_recover_1", gjson.GetBytes(firstTurn, "response.id").String())
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_prev_recover_1"}`)
+ secondTurn := readMessage()
+ require.Equal(t, "response.completed", gjson.GetBytes(secondTurn, "type").String())
+ require.Equal(t, "resp_turn_prev_recover_2", gjson.GetBytes(secondTurn, "response.id").String())
+
+ require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+ select {
+ case serverErr := <-serverErrCh:
+ require.NoError(t, serverErr)
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 结束超时")
+ }
+
+ require.Equal(t, 2, dialer.DialCount(), "previous_response_not_found 恢复应触发换连重试")
+
+ firstConn.mu.Lock()
+ firstWrites := append([]map[string]any(nil), firstConn.writes...)
+ firstConn.mu.Unlock()
+ require.Len(t, firstWrites, 2, "首个连接应处理首轮与失败的第二轮请求")
+ require.True(t, gjson.Get(requestToJSONString(firstWrites[1]), "previous_response_id").Exists(), "失败轮次首发请求应包含 previous_response_id")
+
+ secondConn.mu.Lock()
+ secondWrites := append([]map[string]any(nil), secondConn.writes...)
+ secondConn.mu.Unlock()
+ require.Len(t, secondWrites, 1, "恢复重试应在第二个连接发送一次请求")
+ require.False(t, gjson.Get(requestToJSONString(secondWrites[0]), "previous_response_id").Exists(), "恢复重试应移除 previous_response_id")
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStrictAffinityPreviousResponseNotFoundLayer2Recovery(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ firstConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_strict_recover_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":"missing strict anchor"}}`),
+ },
+ }
+ secondConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_strict_recover_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{firstConn, secondConn},
+ }
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(dialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 122,
+ Name: "openai-ingress-prev-strict-layer2",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeMessage := func(payload string) {
+ writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
+ }
+ readMessage := func() []byte {
+ readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ msgType, message, readErr := clientConn.Read(readCtx)
+ require.NoError(t, readErr)
+ require.Equal(t, coderws.MessageText, msgType)
+ return message
+ }
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"prompt_cache_key":"pk_strict_layer2","input":[{"type":"input_text","text":"hello"}]}`)
+ firstTurn := readMessage()
+ require.Equal(t, "resp_turn_prev_strict_recover_1", gjson.GetBytes(firstTurn, "response.id").String())
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"prompt_cache_key":"pk_strict_layer2","previous_response_id":"resp_turn_prev_strict_recover_1","input":[{"type":"input_text","text":"world"}]}`)
+ secondTurn := readMessage()
+ require.Equal(t, "resp_turn_prev_strict_recover_2", gjson.GetBytes(secondTurn, "response.id").String())
+
+ require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+ select {
+ case serverErr := <-serverErrCh:
+ require.NoError(t, serverErr)
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 严格亲和 Layer2 恢复结束超时")
+ }
+
+ require.Equal(t, 2, dialer.DialCount(), "严格亲和链路命中 previous_response_not_found 应触发 Layer2 恢复重试")
+
+ firstConn.mu.Lock()
+ firstWrites := append([]map[string]any(nil), firstConn.writes...)
+ firstConn.mu.Unlock()
+ require.Len(t, firstWrites, 2, "首连接应收到首轮请求和失败的续链请求")
+ require.True(t, gjson.Get(requestToJSONString(firstWrites[1]), "previous_response_id").Exists())
+
+ secondConn.mu.Lock()
+ secondWrites := append([]map[string]any(nil), secondConn.writes...)
+ secondConn.mu.Unlock()
+ require.Len(t, secondWrites, 1, "Layer2 恢复应仅重放一次")
+ secondWrite := requestToJSONString(secondWrites[0])
+ require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "Layer2 恢复重放应移除 previous_response_id")
+ require.True(t, gjson.Get(secondWrite, "store").Exists(), "Layer2 恢复不应改变 store 标志")
+ require.False(t, gjson.Get(secondWrite, "store").Bool())
+ require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()), "Layer2 恢复应重放完整 input 上下文")
+ require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String())
+ require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String())
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreviousResponseNotFoundRecoveryRemovesDuplicatePrevID(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ firstConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_once_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":"first missing"}}`),
+ },
+ }
+ secondConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_once_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{firstConn, secondConn},
+ }
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(dialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 120,
+ Name: "openai-ingress-prev-recovery-once",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeMessage := func(payload string) {
+ writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
+ }
+ readMessage := func() []byte {
+ readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ msgType, message, readErr := clientConn.Read(readCtx)
+ require.NoError(t, readErr)
+ require.Equal(t, coderws.MessageText, msgType)
+ return message
+ }
+
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`)
+ firstTurn := readMessage()
+ require.Equal(t, "resp_turn_prev_once_1", gjson.GetBytes(firstTurn, "response.id").String())
+
+ // duplicate previous_response_id: 恢复重试时应删除所有重复键,避免再次 previous_response_not_found。
+ writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_prev_once_1","input":[],"previous_response_id":"resp_turn_prev_duplicate"}`)
+ secondTurn := readMessage()
+ require.Equal(t, "resp_turn_prev_once_2", gjson.GetBytes(secondTurn, "response.id").String())
+
+ require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
+ select {
+ case serverErr := <-serverErrCh:
+ require.NoError(t, serverErr)
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 结束超时")
+ }
+
+ require.Equal(t, 2, dialer.DialCount(), "previous_response_not_found 恢复应只重试一次")
+
+ firstConn.mu.Lock()
+ firstWrites := append([]map[string]any(nil), firstConn.writes...)
+ firstConn.mu.Unlock()
+ require.Len(t, firstWrites, 2)
+ require.True(t, gjson.Get(requestToJSONString(firstWrites[1]), "previous_response_id").Exists())
+
+ secondConn.mu.Lock()
+ secondWrites := append([]map[string]any(nil), secondConn.writes...)
+ secondConn.mu.Unlock()
+ require.Len(t, secondWrites, 1)
+ require.False(t, gjson.Get(requestToJSONString(secondWrites[0]), "previous_response_id").Exists(), "重复键场景恢复重试后不应保留 previous_response_id")
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ }
+
+ account := &Account{
+ ID: 119,
+ Name: "openai-ingress-prev-validation",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+ err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_123456"}`))
+ cancelWrite()
+ require.NoError(t, err)
+
+ select {
+ case serverErr := <-serverErrCh:
+ require.Error(t, serverErr)
+ var closeErr *OpenAIWSClientCloseError
+ require.ErrorAs(t, serverErr, &closeErr)
+ require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
+ require.Contains(t, closeErr.Reason(), "previous_response_id must be a response.id")
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 结束超时")
+ }
+}
+
+type openAIWSQueueDialer struct {
+ mu sync.Mutex
+ conns []openAIWSClientConn
+ dialCount int
+}
+
+func (d *openAIWSQueueDialer) Dial(
+ ctx context.Context,
+ wsURL string,
+ headers http.Header,
+ proxyURL string,
+) (openAIWSClientConn, int, http.Header, error) {
+ _ = ctx
+ _ = wsURL
+ _ = headers
+ _ = proxyURL
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ d.dialCount++
+ if len(d.conns) == 0 {
+ return nil, 503, nil, errors.New("no test conn")
+ }
+ conn := d.conns[0]
+ if len(d.conns) > 1 {
+ d.conns = d.conns[1:]
+ }
+ return conn, 0, nil, nil
+}
+
+func (d *openAIWSQueueDialer) DialCount() int {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return d.dialCount
+}
+
+type openAIWSPreflightFailConn struct {
+ mu sync.Mutex
+ events [][]byte
+ pingFails bool
+ writeCount int
+ pingCount int
+}
+
+func (c *openAIWSPreflightFailConn) WriteJSON(context.Context, any) error {
+ c.mu.Lock()
+ c.writeCount++
+ c.mu.Unlock()
+ return nil
+}
+
+func (c *openAIWSPreflightFailConn) ReadMessage(context.Context) ([]byte, error) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if len(c.events) == 0 {
+ return nil, io.EOF
+ }
+ event := c.events[0]
+ c.events = c.events[1:]
+ if len(c.events) == 0 {
+ c.pingFails = true
+ }
+ return event, nil
+}
+
+func (c *openAIWSPreflightFailConn) Ping(context.Context) error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.pingCount++
+ if c.pingFails {
+ return errors.New("preflight ping failed")
+ }
+ return nil
+}
+
+func (c *openAIWSPreflightFailConn) Close() error {
+ return nil
+}
+
+func (c *openAIWSPreflightFailConn) WriteCount() int {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return c.writeCount
+}
+
+func (c *openAIWSPreflightFailConn) PingCount() int {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return c.pingCount
+}
+
+type openAIWSWriteFailAfterFirstTurnConn struct {
+ mu sync.Mutex
+ events [][]byte
+ failOnWrite bool
+}
+
+func (c *openAIWSWriteFailAfterFirstTurnConn) WriteJSON(context.Context, any) error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.failOnWrite {
+ return errors.New("write failed on stale conn")
+ }
+ return nil
+}
+
+func (c *openAIWSWriteFailAfterFirstTurnConn) ReadMessage(context.Context) ([]byte, error) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if len(c.events) == 0 {
+ return nil, io.EOF
+ }
+ event := c.events[0]
+ c.events = c.events[1:]
+ if len(c.events) == 0 {
+ c.failOnWrite = true
+ }
+ return event, nil
+}
+
+func (c *openAIWSWriteFailAfterFirstTurnConn) Ping(context.Context) error {
+ return nil
+}
+
+func (c *openAIWSWriteFailAfterFirstTurnConn) Close() error {
+ return nil
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ClientDisconnectStillDrainsUpstream(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ // 多个上游事件:前几个为非 terminal 事件,最后一个为 terminal。
+ // 第一个事件延迟 250ms 让客户端 RST 有时间传播,使 writeClientMessage 可靠失败。
+ captureConn := &openAIWSCaptureConn{
+ readDelays: []time.Duration{250 * time.Millisecond, 0, 0},
+ events: [][]byte{
+ []byte(`{"type":"response.created","response":{"id":"resp_ingress_disconnect","model":"gpt-5.1"}}`),
+ []byte(`{"type":"response.output_item.added","response":{"id":"resp_ingress_disconnect"}}`),
+ []byte(`{"type":"response.completed","response":{"id":"resp_ingress_disconnect","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`),
+ },
+ }
+ captureDialer := &openAIWSCaptureDialer{conn: captureConn}
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(captureDialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 115,
+ Name: "openai-ingress-client-disconnect",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "model_mapping": map[string]any{
+ "custom-original-model": "gpt-5.1",
+ },
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ resultCh := make(chan *OpenAIForwardResult, 1)
+ hooks := &OpenAIWSIngressHooks{
+ AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) {
+ if turnErr == nil && result != nil {
+ resultCh <- result
+ }
+ },
+ }
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+
+ writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+ err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"custom-original-model","stream":false}`))
+ cancelWrite()
+ require.NoError(t, err)
+ // 立即关闭客户端,模拟客户端在 relay 期间断连。
+ require.NoError(t, clientConn.CloseNow(), "模拟 ingress 客户端提前断连")
+
+ select {
+ case serverErr := <-serverErrCh:
+ require.NoError(t, serverErr, "客户端断连后应继续 drain 上游直到 terminal 或正常结束")
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 结束超时")
+ }
+
+ select {
+ case result := <-resultCh:
+ require.Equal(t, "resp_ingress_disconnect", result.RequestID)
+ require.Equal(t, 2, result.Usage.InputTokens)
+ require.Equal(t, 1, result.Usage.OutputTokens)
+ case <-time.After(2 * time.Second):
+ t.Fatal("未收到断连后的 turn 结果回调")
+ }
+}
diff --git a/backend/internal/service/openai_ws_forwarder_ingress_test.go b/backend/internal/service/openai_ws_forwarder_ingress_test.go
new file mode 100644
index 00000000..ff35cb01
--- /dev/null
+++ b/backend/internal/service/openai_ws_forwarder_ingress_test.go
@@ -0,0 +1,714 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "io"
+ "net"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ coderws "github.com/coder/websocket"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+func TestIsOpenAIWSClientDisconnectError(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ err error
+ want bool
+ }{
+ {name: "nil", err: nil, want: false},
+ {name: "io_eof", err: io.EOF, want: true},
+ {name: "net_closed", err: net.ErrClosed, want: true},
+ {name: "context_canceled", err: context.Canceled, want: true},
+ {name: "ws_normal_closure", err: coderws.CloseError{Code: coderws.StatusNormalClosure}, want: true},
+ {name: "ws_going_away", err: coderws.CloseError{Code: coderws.StatusGoingAway}, want: true},
+ {name: "ws_no_status", err: coderws.CloseError{Code: coderws.StatusNoStatusRcvd}, want: true},
+ {name: "ws_abnormal_1006", err: coderws.CloseError{Code: coderws.StatusAbnormalClosure}, want: true},
+ {name: "ws_policy_violation", err: coderws.CloseError{Code: coderws.StatusPolicyViolation}, want: false},
+ {name: "wrapped_eof_message", err: errors.New("failed to get reader: failed to read frame header: EOF"), want: true},
+ {name: "connection_reset_by_peer", err: errors.New("failed to read frame header: read tcp 127.0.0.1:1234->127.0.0.1:5678: read: connection reset by peer"), want: true},
+ {name: "broken_pipe", err: errors.New("write tcp 127.0.0.1:1234->127.0.0.1:5678: write: broken pipe"), want: true},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.want, isOpenAIWSClientDisconnectError(tt.err))
+ })
+ }
+}
+
+func TestIsOpenAIWSIngressPreviousResponseNotFound(t *testing.T) {
+ t.Parallel()
+
+ require.False(t, isOpenAIWSIngressPreviousResponseNotFound(nil))
+ require.False(t, isOpenAIWSIngressPreviousResponseNotFound(errors.New("plain error")))
+ require.False(t, isOpenAIWSIngressPreviousResponseNotFound(
+ wrapOpenAIWSIngressTurnError("read_upstream", errors.New("upstream read failed"), false),
+ ))
+ require.False(t, isOpenAIWSIngressPreviousResponseNotFound(
+ wrapOpenAIWSIngressTurnError(openAIWSIngressStagePreviousResponseNotFound, errors.New("previous response not found"), true),
+ ))
+ require.True(t, isOpenAIWSIngressPreviousResponseNotFound(
+ wrapOpenAIWSIngressTurnError(openAIWSIngressStagePreviousResponseNotFound, errors.New("previous response not found"), false),
+ ))
+}
+
+func TestOpenAIWSIngressPreviousResponseRecoveryEnabled(t *testing.T) {
+ t.Parallel()
+
+ var nilService *OpenAIGatewayService
+ require.True(t, nilService.openAIWSIngressPreviousResponseRecoveryEnabled(), "nil service should default to enabled")
+
+ svcWithNilCfg := &OpenAIGatewayService{}
+ require.True(t, svcWithNilCfg.openAIWSIngressPreviousResponseRecoveryEnabled(), "nil config should default to enabled")
+
+ svc := &OpenAIGatewayService{
+ cfg: &config.Config{},
+ }
+ require.False(t, svc.openAIWSIngressPreviousResponseRecoveryEnabled(), "explicit config default should be false")
+
+ svc.cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true
+ require.True(t, svc.openAIWSIngressPreviousResponseRecoveryEnabled())
+}
+
+func TestDropPreviousResponseIDFromRawPayload(t *testing.T) {
+ t.Parallel()
+
+ t.Run("empty_payload", func(t *testing.T) {
+ updated, removed, err := dropPreviousResponseIDFromRawPayload(nil)
+ require.NoError(t, err)
+ require.False(t, removed)
+ require.Empty(t, updated)
+ })
+
+ t.Run("payload_without_previous_response_id", func(t *testing.T) {
+ payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`)
+ updated, removed, err := dropPreviousResponseIDFromRawPayload(payload)
+ require.NoError(t, err)
+ require.False(t, removed)
+ require.Equal(t, string(payload), string(updated))
+ })
+
+ t.Run("normal_delete_success", func(t *testing.T) {
+ payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`)
+ updated, removed, err := dropPreviousResponseIDFromRawPayload(payload)
+ require.NoError(t, err)
+ require.True(t, removed)
+ require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists())
+ })
+
+ t.Run("duplicate_keys_are_removed", func(t *testing.T) {
+ payload := []byte(`{"type":"response.create","previous_response_id":"resp_a","input":[],"previous_response_id":"resp_b"}`)
+ updated, removed, err := dropPreviousResponseIDFromRawPayload(payload)
+ require.NoError(t, err)
+ require.True(t, removed)
+ require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists())
+ })
+
+ t.Run("nil_delete_fn_uses_default_delete_logic", func(t *testing.T) {
+ payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`)
+ updated, removed, err := dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, nil)
+ require.NoError(t, err)
+ require.True(t, removed)
+ require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists())
+ })
+
+ t.Run("delete_error", func(t *testing.T) {
+ payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`)
+ updated, removed, err := dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, func(_ []byte, _ string) ([]byte, error) {
+ return nil, errors.New("delete failed")
+ })
+ require.Error(t, err)
+ require.False(t, removed)
+ require.Equal(t, string(payload), string(updated))
+ })
+
+ t.Run("malformed_json_is_still_best_effort_deleted", func(t *testing.T) {
+ payload := []byte(`{"type":"response.create","previous_response_id":"resp_abc"`)
+ require.True(t, gjson.GetBytes(payload, "previous_response_id").Exists())
+
+ updated, removed, err := dropPreviousResponseIDFromRawPayload(payload)
+ require.NoError(t, err)
+ require.True(t, removed)
+ require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists())
+ })
+}
+
+func TestAlignStoreDisabledPreviousResponseID(t *testing.T) {
+ t.Parallel()
+
+ t.Run("empty_payload", func(t *testing.T) {
+ updated, changed, err := alignStoreDisabledPreviousResponseID(nil, "resp_target")
+ require.NoError(t, err)
+ require.False(t, changed)
+ require.Empty(t, updated)
+ })
+
+ t.Run("empty_expected", func(t *testing.T) {
+ payload := []byte(`{"type":"response.create","previous_response_id":"resp_old"}`)
+ updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "")
+ require.NoError(t, err)
+ require.False(t, changed)
+ require.Equal(t, string(payload), string(updated))
+ })
+
+ t.Run("missing_previous_response_id", func(t *testing.T) {
+ payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`)
+ updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target")
+ require.NoError(t, err)
+ require.False(t, changed)
+ require.Equal(t, string(payload), string(updated))
+ })
+
+ t.Run("already_aligned", func(t *testing.T) {
+ payload := []byte(`{"type":"response.create","previous_response_id":"resp_target"}`)
+ updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target")
+ require.NoError(t, err)
+ require.False(t, changed)
+ require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String())
+ })
+
+ t.Run("mismatch_rewrites_to_expected", func(t *testing.T) {
+ payload := []byte(`{"type":"response.create","previous_response_id":"resp_old","input":[]}`)
+ updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target")
+ require.NoError(t, err)
+ require.True(t, changed)
+ require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String())
+ })
+
+ t.Run("duplicate_keys_rewrites_to_single_expected", func(t *testing.T) {
+ payload := []byte(`{"type":"response.create","previous_response_id":"resp_old_1","input":[],"previous_response_id":"resp_old_2"}`)
+ updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target")
+ require.NoError(t, err)
+ require.True(t, changed)
+ require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String())
+ })
+}
+
+func TestSetPreviousResponseIDToRawPayload(t *testing.T) {
+ t.Parallel()
+
+ t.Run("empty_payload", func(t *testing.T) {
+ updated, err := setPreviousResponseIDToRawPayload(nil, "resp_target")
+ require.NoError(t, err)
+ require.Empty(t, updated)
+ })
+
+ t.Run("empty_previous_response_id", func(t *testing.T) {
+ payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`)
+ updated, err := setPreviousResponseIDToRawPayload(payload, "")
+ require.NoError(t, err)
+ require.Equal(t, string(payload), string(updated))
+ })
+
+ t.Run("set_previous_response_id_when_missing", func(t *testing.T) {
+ payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`)
+ updated, err := setPreviousResponseIDToRawPayload(payload, "resp_target")
+ require.NoError(t, err)
+ require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String())
+ require.Equal(t, "gpt-5.1", gjson.GetBytes(updated, "model").String())
+ })
+
+ t.Run("overwrite_existing_previous_response_id", func(t *testing.T) {
+ payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_old"}`)
+ updated, err := setPreviousResponseIDToRawPayload(payload, "resp_new")
+ require.NoError(t, err)
+ require.Equal(t, "resp_new", gjson.GetBytes(updated, "previous_response_id").String())
+ })
+}
+
+func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ storeDisabled bool
+ turn int
+ hasFunctionCallOutput bool
+ currentPreviousResponse string
+ expectedPrevious string
+ want bool
+ }{
+ {
+ name: "infer_when_all_conditions_match",
+ storeDisabled: true,
+ turn: 2,
+ hasFunctionCallOutput: true,
+ expectedPrevious: "resp_1",
+ want: true,
+ },
+ {
+ name: "skip_when_store_enabled",
+ storeDisabled: false,
+ turn: 2,
+ hasFunctionCallOutput: true,
+ expectedPrevious: "resp_1",
+ want: false,
+ },
+ {
+ name: "skip_on_first_turn",
+ storeDisabled: true,
+ turn: 1,
+ hasFunctionCallOutput: true,
+ expectedPrevious: "resp_1",
+ want: false,
+ },
+ {
+ name: "skip_without_function_call_output",
+ storeDisabled: true,
+ turn: 2,
+ hasFunctionCallOutput: false,
+ expectedPrevious: "resp_1",
+ want: false,
+ },
+ {
+ name: "skip_when_request_already_has_previous_response_id",
+ storeDisabled: true,
+ turn: 2,
+ hasFunctionCallOutput: true,
+ currentPreviousResponse: "resp_client",
+ expectedPrevious: "resp_1",
+ want: false,
+ },
+ {
+ name: "skip_when_last_turn_response_id_missing",
+ storeDisabled: true,
+ turn: 2,
+ hasFunctionCallOutput: true,
+ expectedPrevious: "",
+ want: false,
+ },
+ {
+ name: "trim_whitespace_before_judgement",
+ storeDisabled: true,
+ turn: 2,
+ hasFunctionCallOutput: true,
+ expectedPrevious: " resp_2 ",
+ want: true,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got := shouldInferIngressFunctionCallOutputPreviousResponseID(
+ tt.storeDisabled,
+ tt.turn,
+ tt.hasFunctionCallOutput,
+ tt.currentPreviousResponse,
+ tt.expectedPrevious,
+ )
+ require.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestOpenAIWSInputIsPrefixExtended(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ previous []byte
+ current []byte
+ want bool
+ expectErr bool
+ }{
+ {
+ name: "both_missing_input",
+ previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`),
+ current: []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_1"}`),
+ want: true,
+ },
+ {
+ name: "previous_missing_current_empty_array",
+ previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`),
+ current: []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`),
+ want: true,
+ },
+ {
+ name: "previous_missing_current_non_empty_array",
+ previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`),
+ current: []byte(`{"type":"response.create","model":"gpt-5.1","input":[{"type":"input_text","text":"hello"}]}`),
+ want: false,
+ },
+ {
+ name: "array_prefix_match",
+ previous: []byte(`{"input":[{"type":"input_text","text":"hello"}]}`),
+ current: []byte(`{"input":[{"text":"hello","type":"input_text"},{"type":"input_text","text":"world"}]}`),
+ want: true,
+ },
+ {
+ name: "array_prefix_mismatch",
+ previous: []byte(`{"input":[{"type":"input_text","text":"hello"}]}`),
+ current: []byte(`{"input":[{"type":"input_text","text":"different"}]}`),
+ want: false,
+ },
+ {
+ name: "current_shorter_than_previous",
+ previous: []byte(`{"input":[{"type":"input_text","text":"a"},{"type":"input_text","text":"b"}]}`),
+ current: []byte(`{"input":[{"type":"input_text","text":"a"}]}`),
+ want: false,
+ },
+ {
+ name: "previous_has_input_current_missing",
+ previous: []byte(`{"input":[{"type":"input_text","text":"a"}]}`),
+ current: []byte(`{"model":"gpt-5.1"}`),
+ want: false,
+ },
+ {
+ name: "input_string_treated_as_single_item",
+ previous: []byte(`{"input":"hello"}`),
+ current: []byte(`{"input":"hello"}`),
+ want: true,
+ },
+ {
+ name: "current_invalid_input_json",
+ previous: []byte(`{"input":[]}`),
+ current: []byte(`{"input":[}`),
+ expectErr: true,
+ },
+ {
+ name: "invalid_input_json",
+ previous: []byte(`{"input":[}`),
+ current: []byte(`{"input":[]}`),
+ expectErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got, err := openAIWSInputIsPrefixExtended(tt.previous, tt.current)
+ if tt.expectErr {
+ require.Error(t, err)
+ return
+ }
+ require.NoError(t, err)
+ require.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestNormalizeOpenAIWSJSONForCompare(t *testing.T) {
+ t.Parallel()
+
+ normalized, err := normalizeOpenAIWSJSONForCompare([]byte(`{"b":2,"a":1}`))
+ require.NoError(t, err)
+ require.Equal(t, `{"a":1,"b":2}`, string(normalized))
+
+ _, err = normalizeOpenAIWSJSONForCompare([]byte(" "))
+ require.Error(t, err)
+
+ _, err = normalizeOpenAIWSJSONForCompare([]byte(`{"a":`))
+ require.Error(t, err)
+}
+
+func TestNormalizeOpenAIWSJSONForCompareOrRaw(t *testing.T) {
+ t.Parallel()
+
+ require.Equal(t, `{"a":1,"b":2}`, string(normalizeOpenAIWSJSONForCompareOrRaw([]byte(`{"b":2,"a":1}`))))
+ require.Equal(t, `{"a":`, string(normalizeOpenAIWSJSONForCompareOrRaw([]byte(`{"a":`))))
+}
+
+func TestNormalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(t *testing.T) {
+ t.Parallel()
+
+ normalized, err := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(
+ []byte(`{"model":"gpt-5.1","input":[1],"previous_response_id":"resp_x","metadata":{"b":2,"a":1}}`),
+ )
+ require.NoError(t, err)
+ require.False(t, gjson.GetBytes(normalized, "input").Exists())
+ require.False(t, gjson.GetBytes(normalized, "previous_response_id").Exists())
+ require.Equal(t, float64(1), gjson.GetBytes(normalized, "metadata.a").Float())
+
+ _, err = normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(nil)
+ require.Error(t, err)
+
+ _, err = normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID([]byte(`[]`))
+ require.Error(t, err)
+}
+
+func TestOpenAIWSExtractNormalizedInputSequence(t *testing.T) {
+ t.Parallel()
+
+ t.Run("empty_payload", func(t *testing.T) {
+ items, exists, err := openAIWSExtractNormalizedInputSequence(nil)
+ require.NoError(t, err)
+ require.False(t, exists)
+ require.Nil(t, items)
+ })
+
+ t.Run("input_missing", func(t *testing.T) {
+ items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"type":"response.create"}`))
+ require.NoError(t, err)
+ require.False(t, exists)
+ require.Nil(t, items)
+ })
+
+ t.Run("input_array", func(t *testing.T) {
+ items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":[{"type":"input_text","text":"hello"}]}`))
+ require.NoError(t, err)
+ require.True(t, exists)
+ require.Len(t, items, 1)
+ })
+
+ t.Run("input_object", func(t *testing.T) {
+ items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":{"type":"input_text","text":"hello"}}`))
+ require.NoError(t, err)
+ require.True(t, exists)
+ require.Len(t, items, 1)
+ })
+
+ t.Run("input_string", func(t *testing.T) {
+ items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":"hello"}`))
+ require.NoError(t, err)
+ require.True(t, exists)
+ require.Len(t, items, 1)
+ require.Equal(t, `"hello"`, string(items[0]))
+ })
+
+ t.Run("input_number", func(t *testing.T) {
+ items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":42}`))
+ require.NoError(t, err)
+ require.True(t, exists)
+ require.Len(t, items, 1)
+ require.Equal(t, "42", string(items[0]))
+ })
+
+ t.Run("input_bool", func(t *testing.T) {
+ items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":true}`))
+ require.NoError(t, err)
+ require.True(t, exists)
+ require.Len(t, items, 1)
+ require.Equal(t, "true", string(items[0]))
+ })
+
+ t.Run("input_null", func(t *testing.T) {
+ items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":null}`))
+ require.NoError(t, err)
+ require.True(t, exists)
+ require.Len(t, items, 1)
+ require.Equal(t, "null", string(items[0]))
+ })
+
+ t.Run("input_invalid_array_json", func(t *testing.T) {
+ items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":[}`))
+ require.Error(t, err)
+ require.True(t, exists)
+ require.Nil(t, items)
+ })
+}
+
+func TestShouldKeepIngressPreviousResponseID(t *testing.T) {
+ t.Parallel()
+
+ previousPayload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "store":false,
+ "tools":[{"type":"function","name":"tool_a"}],
+ "input":[{"type":"input_text","text":"hello"}]
+ }`)
+ currentStrictPayload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "store":false,
+ "tools":[{"name":"tool_a","type":"function"}],
+ "previous_response_id":"resp_turn_1",
+ "input":[{"text":"hello","type":"input_text"},{"type":"input_text","text":"world"}]
+ }`)
+
+ t.Run("strict_incremental_keep", func(t *testing.T) {
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_1", false)
+ require.NoError(t, err)
+ require.True(t, keep)
+ require.Equal(t, "strict_incremental_ok", reason)
+ })
+
+ t.Run("missing_previous_response_id", func(t *testing.T) {
+ payload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false)
+ require.NoError(t, err)
+ require.False(t, keep)
+ require.Equal(t, "missing_previous_response_id", reason)
+ })
+
+ t.Run("missing_last_turn_response_id", func(t *testing.T) {
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "", false)
+ require.NoError(t, err)
+ require.False(t, keep)
+ require.Equal(t, "missing_last_turn_response_id", reason)
+ })
+
+ t.Run("previous_response_id_mismatch", func(t *testing.T) {
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_other", false)
+ require.NoError(t, err)
+ require.False(t, keep)
+ require.Equal(t, "previous_response_id_mismatch", reason)
+ })
+
+ t.Run("missing_previous_turn_payload", func(t *testing.T) {
+ keep, reason, err := shouldKeepIngressPreviousResponseID(nil, currentStrictPayload, "resp_turn_1", false)
+ require.NoError(t, err)
+ require.False(t, keep)
+ require.Equal(t, "missing_previous_turn_payload", reason)
+ })
+
+ t.Run("non_input_changed", func(t *testing.T) {
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1-mini",
+ "store":false,
+ "tools":[{"type":"function","name":"tool_a"}],
+ "previous_response_id":"resp_turn_1",
+ "input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"}]
+ }`)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false)
+ require.NoError(t, err)
+ require.False(t, keep)
+ require.Equal(t, "non_input_changed", reason)
+ })
+
+ t.Run("delta_input_keeps_previous_response_id", func(t *testing.T) {
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "store":false,
+ "tools":[{"type":"function","name":"tool_a"}],
+ "previous_response_id":"resp_turn_1",
+ "input":[{"type":"input_text","text":"different"}]
+ }`)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false)
+ require.NoError(t, err)
+ require.True(t, keep)
+ require.Equal(t, "strict_incremental_ok", reason)
+ })
+
+ t.Run("function_call_output_keeps_previous_response_id", func(t *testing.T) {
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "store":false,
+ "previous_response_id":"resp_external",
+ "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
+ }`)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", true)
+ require.NoError(t, err)
+ require.True(t, keep)
+ require.Equal(t, "has_function_call_output", reason)
+ })
+
+ t.Run("non_input_compare_error", func(t *testing.T) {
+ keep, reason, err := shouldKeepIngressPreviousResponseID([]byte(`[]`), currentStrictPayload, "resp_turn_1", false)
+ require.Error(t, err)
+ require.False(t, keep)
+ require.Equal(t, "non_input_compare_error", reason)
+ })
+
+ t.Run("current_payload_compare_error", func(t *testing.T) {
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, []byte(`{"previous_response_id":"resp_turn_1","input":[}`), "resp_turn_1", false)
+ require.Error(t, err)
+ require.False(t, keep)
+ require.Equal(t, "non_input_compare_error", reason)
+ })
+}
+
+func TestBuildOpenAIWSReplayInputSequence(t *testing.T) {
+ t.Parallel()
+
+ lastFull := []json.RawMessage{
+ json.RawMessage(`{"type":"input_text","text":"hello"}`),
+ }
+
+ t.Run("no_previous_response_id_use_current", func(t *testing.T) {
+ items, exists, err := buildOpenAIWSReplayInputSequence(
+ lastFull,
+ true,
+ []byte(`{"input":[{"type":"input_text","text":"new"}]}`),
+ false,
+ )
+ require.NoError(t, err)
+ require.True(t, exists)
+ require.Len(t, items, 1)
+ require.Equal(t, "new", gjson.GetBytes(items[0], "text").String())
+ })
+
+ t.Run("previous_response_id_delta_append", func(t *testing.T) {
+ items, exists, err := buildOpenAIWSReplayInputSequence(
+ lastFull,
+ true,
+ []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"world"}]}`),
+ true,
+ )
+ require.NoError(t, err)
+ require.True(t, exists)
+ require.Len(t, items, 2)
+ require.Equal(t, "hello", gjson.GetBytes(items[0], "text").String())
+ require.Equal(t, "world", gjson.GetBytes(items[1], "text").String())
+ })
+
+ t.Run("previous_response_id_full_input_replace", func(t *testing.T) {
+ items, exists, err := buildOpenAIWSReplayInputSequence(
+ lastFull,
+ true,
+ []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"}]}`),
+ true,
+ )
+ require.NoError(t, err)
+ require.True(t, exists)
+ require.Len(t, items, 2)
+ require.Equal(t, "hello", gjson.GetBytes(items[0], "text").String())
+ require.Equal(t, "world", gjson.GetBytes(items[1], "text").String())
+ })
+}
+
+func TestSetOpenAIWSPayloadInputSequence(t *testing.T) {
+ t.Parallel()
+
+ t.Run("set_items", func(t *testing.T) {
+ original := []byte(`{"type":"response.create","previous_response_id":"resp_1"}`)
+ items := []json.RawMessage{
+ json.RawMessage(`{"type":"input_text","text":"hello"}`),
+ json.RawMessage(`{"type":"input_text","text":"world"}`),
+ }
+ updated, err := setOpenAIWSPayloadInputSequence(original, items, true)
+ require.NoError(t, err)
+ require.Equal(t, "hello", gjson.GetBytes(updated, "input.0.text").String())
+ require.Equal(t, "world", gjson.GetBytes(updated, "input.1.text").String())
+ })
+
+ t.Run("preserve_empty_array_not_null", func(t *testing.T) {
+ original := []byte(`{"type":"response.create","previous_response_id":"resp_1"}`)
+ updated, err := setOpenAIWSPayloadInputSequence(original, nil, true)
+ require.NoError(t, err)
+ require.True(t, gjson.GetBytes(updated, "input").IsArray())
+ require.Len(t, gjson.GetBytes(updated, "input").Array(), 0)
+ require.False(t, gjson.GetBytes(updated, "input").Type == gjson.Null)
+ })
+}
+
+func TestCloneOpenAIWSRawMessages(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil_slice", func(t *testing.T) {
+ cloned := cloneOpenAIWSRawMessages(nil)
+ require.Nil(t, cloned)
+ })
+
+ t.Run("empty_slice", func(t *testing.T) {
+ items := make([]json.RawMessage, 0)
+ cloned := cloneOpenAIWSRawMessages(items)
+ require.NotNil(t, cloned)
+ require.Len(t, cloned, 0)
+ })
+}
diff --git a/backend/internal/service/openai_ws_forwarder_retry_payload_test.go b/backend/internal/service/openai_ws_forwarder_retry_payload_test.go
new file mode 100644
index 00000000..0ea7e1c7
--- /dev/null
+++ b/backend/internal/service/openai_ws_forwarder_retry_payload_test.go
@@ -0,0 +1,50 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestApplyOpenAIWSRetryPayloadStrategy_KeepPromptCacheKey(t *testing.T) {
+ payload := map[string]any{
+ "model": "gpt-5.3-codex",
+ "prompt_cache_key": "pcache_123",
+ "include": []any{"reasoning.encrypted_content"},
+ "text": map[string]any{
+ "verbosity": "low",
+ },
+ "tools": []any{map[string]any{"type": "function"}},
+ }
+
+ strategy, removed := applyOpenAIWSRetryPayloadStrategy(payload, 3)
+ require.Equal(t, "trim_optional_fields", strategy)
+ require.Contains(t, removed, "include")
+ require.NotContains(t, removed, "prompt_cache_key")
+ require.Equal(t, "pcache_123", payload["prompt_cache_key"])
+ require.NotContains(t, payload, "include")
+ require.Contains(t, payload, "text")
+}
+
+func TestApplyOpenAIWSRetryPayloadStrategy_AttemptSixKeepsSemanticFields(t *testing.T) {
+ payload := map[string]any{
+ "prompt_cache_key": "pcache_456",
+ "instructions": "long instructions",
+ "tools": []any{map[string]any{"type": "function"}},
+ "parallel_tool_calls": true,
+ "tool_choice": "auto",
+ "include": []any{"reasoning.encrypted_content"},
+ "text": map[string]any{"verbosity": "high"},
+ }
+
+ strategy, removed := applyOpenAIWSRetryPayloadStrategy(payload, 6)
+ require.Equal(t, "trim_optional_fields", strategy)
+ require.Contains(t, removed, "include")
+ require.NotContains(t, removed, "prompt_cache_key")
+ require.Equal(t, "pcache_456", payload["prompt_cache_key"])
+ require.Contains(t, payload, "instructions")
+ require.Contains(t, payload, "tools")
+ require.Contains(t, payload, "tool_choice")
+ require.Contains(t, payload, "parallel_tool_calls")
+ require.Contains(t, payload, "text")
+}
diff --git a/backend/internal/service/openai_ws_forwarder_success_test.go b/backend/internal/service/openai_ws_forwarder_success_test.go
new file mode 100644
index 00000000..592801f6
--- /dev/null
+++ b/backend/internal/service/openai_ws_forwarder_success_test.go
@@ -0,0 +1,1306 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/gin-gonic/gin"
+ "github.com/gorilla/websocket"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+func TestOpenAIGatewayService_Forward_WSv2_SuccessAndBindSticky(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ type receivedPayload struct {
+ Type string
+ PreviousResponseID string
+ StreamExists bool
+ Stream bool
+ }
+ receivedCh := make(chan receivedPayload, 1)
+
+ upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ t.Errorf("upgrade websocket failed: %v", err)
+ return
+ }
+ defer func() {
+ _ = conn.Close()
+ }()
+
+ var request map[string]any
+ if err := conn.ReadJSON(&request); err != nil {
+ t.Errorf("read ws request failed: %v", err)
+ return
+ }
+ requestJSON := requestToJSONString(request)
+ receivedCh <- receivedPayload{
+ Type: strings.TrimSpace(gjson.Get(requestJSON, "type").String()),
+ PreviousResponseID: strings.TrimSpace(gjson.Get(requestJSON, "previous_response_id").String()),
+ StreamExists: gjson.Get(requestJSON, "stream").Exists(),
+ Stream: gjson.Get(requestJSON, "stream").Bool(),
+ }
+
+ if err := conn.WriteJSON(map[string]any{
+ "type": "response.created",
+ "response": map[string]any{
+ "id": "resp_new_1",
+ "model": "gpt-5.1",
+ },
+ }); err != nil {
+ t.Errorf("write response.created failed: %v", err)
+ return
+ }
+ if err := conn.WriteJSON(map[string]any{
+ "type": "response.completed",
+ "response": map[string]any{
+ "id": "resp_new_1",
+ "model": "gpt-5.1",
+ "usage": map[string]any{
+ "input_tokens": 12,
+ "output_tokens": 7,
+ "input_tokens_details": map[string]any{
+ "cached_tokens": 3,
+ },
+ },
+ },
+ }); err != nil {
+ t.Errorf("write response.completed failed: %v", err)
+ return
+ }
+ }))
+ defer wsServer.Close()
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "unit-test-agent/1.0")
+ groupID := int64(1001)
+ c.Set("api_key", &APIKey{GroupID: &groupID})
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 30
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 10
+ cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)),
+ },
+ }
+
+ cache := &stubGatewayCache{}
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: upstream,
+ cache: cache,
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ }
+
+ account := &Account{
+ ID: 9,
+ Name: "openai-ws",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 2,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": wsServer.URL,
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_1","input":[{"type":"input_text","text":"hello"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, 12, result.Usage.InputTokens)
+ require.Equal(t, 7, result.Usage.OutputTokens)
+ require.Equal(t, 3, result.Usage.CacheReadInputTokens)
+ require.Equal(t, "resp_new_1", result.RequestID)
+ require.True(t, result.OpenAIWSMode)
+ require.False(t, gjson.GetBytes(upstream.lastBody, "model").Exists(), "WSv2 成功时不应回落 HTTP 上游")
+
+ received := <-receivedCh
+ require.Equal(t, "response.create", received.Type)
+ require.Equal(t, "resp_prev_1", received.PreviousResponseID)
+ require.True(t, received.StreamExists, "WS 请求应携带 stream 字段")
+ require.False(t, received.Stream, "应保持客户端 stream=false 的原始语义")
+
+ store := svc.getOpenAIWSStateStore()
+ mappedAccountID, getErr := store.GetResponseAccount(context.Background(), groupID, "resp_new_1")
+ require.NoError(t, getErr)
+ require.Equal(t, account.ID, mappedAccountID)
+ connID, ok := store.GetResponseConn("resp_new_1")
+ require.True(t, ok)
+ require.NotEmpty(t, connID)
+
+ responseBody := rec.Body.Bytes()
+ require.Equal(t, "resp_new_1", gjson.GetBytes(responseBody, "id").String())
+}
+
+func requestToJSONString(payload map[string]any) string {
+ if len(payload) == 0 {
+ return "{}"
+ }
+ b, err := json.Marshal(payload)
+ if err != nil {
+ return "{}"
+ }
+ return string(b)
+}
+
+func TestLogOpenAIWSBindResponseAccountWarn(t *testing.T) {
+ require.NotPanics(t, func() {
+ logOpenAIWSBindResponseAccountWarn(1, 2, "resp_ok", nil)
+ })
+ require.NotPanics(t, func() {
+ logOpenAIWSBindResponseAccountWarn(1, 2, "resp_err", errors.New("bind failed"))
+ })
+}
+
+func TestOpenAIGatewayService_Forward_WSv2_RewriteModelAndToolCallsOnCompletedEvent(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
+ groupID := int64(3001)
+ c.Set("api_key", &APIKey{GroupID: &groupID})
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ captureConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_model_tool_1","model":"gpt-5.1","tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}],"usage":{"input_tokens":2,"output_tokens":1}},"tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}]}`),
+ },
+ }
+ captureDialer := &openAIWSCaptureDialer{conn: captureConn}
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(captureDialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 1301,
+ Name: "openai-rewrite",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "model_mapping": map[string]any{
+ "custom-original-model": "gpt-5.1",
+ },
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"custom-original-model","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "resp_model_tool_1", result.RequestID)
+ require.Equal(t, "custom-original-model", gjson.GetBytes(rec.Body.Bytes(), "model").String(), "响应模型应回写为原始请求模型")
+ require.Equal(t, "edit", gjson.GetBytes(rec.Body.Bytes(), "tool_calls.0.function.name").String(), "工具名称应被修正为 OpenCode 规范")
+}
+
+func TestOpenAIWSPayloadString_OnlyAcceptsStringValues(t *testing.T) {
+ payload := map[string]any{
+ "type": nil,
+ "model": 123,
+ "prompt_cache_key": " cache-key ",
+ "previous_response_id": []byte(" resp_1 "),
+ }
+
+ require.Equal(t, "", openAIWSPayloadString(payload, "type"))
+ require.Equal(t, "", openAIWSPayloadString(payload, "model"))
+ require.Equal(t, "cache-key", openAIWSPayloadString(payload, "prompt_cache_key"))
+ require.Equal(t, "resp_1", openAIWSPayloadString(payload, "previous_response_id"))
+}
+
+func TestOpenAIGatewayService_Forward_WSv2_PoolReuseNotOneToOne(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ var upgradeCount atomic.Int64
+ var sequence atomic.Int64
+ upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ upgradeCount.Add(1)
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ t.Errorf("upgrade websocket failed: %v", err)
+ return
+ }
+ defer func() {
+ _ = conn.Close()
+ }()
+
+ for {
+ var request map[string]any
+ if err := conn.ReadJSON(&request); err != nil {
+ return
+ }
+ idx := sequence.Add(1)
+ responseID := "resp_reuse_" + strconv.FormatInt(idx, 10)
+ if err := conn.WriteJSON(map[string]any{
+ "type": "response.created",
+ "response": map[string]any{
+ "id": responseID,
+ "model": "gpt-5.1",
+ },
+ }); err != nil {
+ return
+ }
+ if err := conn.WriteJSON(map[string]any{
+ "type": "response.completed",
+ "response": map[string]any{
+ "id": responseID,
+ "model": "gpt-5.1",
+ "usage": map[string]any{
+ "input_tokens": 2,
+ "output_tokens": 1,
+ },
+ },
+ }); err != nil {
+ return
+ }
+ }
+ }))
+ defer wsServer.Close()
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 30
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 10
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ }
+ account := &Account{
+ ID: 19,
+ Name: "openai-ws",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 2,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": wsServer.URL,
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ for i := 0; i < 2; i++ {
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
+ groupID := int64(2001)
+ c.Set("api_key", &APIKey{GroupID: &groupID})
+
+ body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_reuse","input":[{"type":"input_text","text":"hello"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.True(t, strings.HasPrefix(result.RequestID, "resp_reuse_"))
+ }
+
+ require.Equal(t, int64(1), upgradeCount.Load(), "多个客户端请求应复用账号连接池而不是 1:1 对等建链")
+ metrics := svc.SnapshotOpenAIWSPoolMetrics()
+ require.GreaterOrEqual(t, metrics.AcquireReuseTotal, int64(1))
+ require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1))
+}
+
+func TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
+ c.Request.Header.Set("session_id", "sess-oauth-1")
+ c.Request.Header.Set("conversation_id", "conv-oauth-1")
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.AllowStoreRecovery = false
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+
+ captureConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_oauth_1","model":"gpt-5.1","usage":{"input_tokens":3,"output_tokens":2}}}`),
+ },
+ }
+ captureDialer := &openAIWSCaptureDialer{conn: captureConn}
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(captureDialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+ account := &Account{
+ ID: 29,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "oauth-token-1",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.1","stream":false,"store":true,"input":[{"type":"input_text","text":"hello"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "resp_oauth_1", result.RequestID)
+
+ require.NotNil(t, captureConn.lastWrite)
+ requestJSON := requestToJSONString(captureConn.lastWrite)
+ require.True(t, gjson.Get(requestJSON, "store").Exists(), "OAuth WSv2 应显式写入 store 字段")
+ require.False(t, gjson.Get(requestJSON, "store").Bool(), "默认策略应将 OAuth store 置为 false")
+ require.True(t, gjson.Get(requestJSON, "stream").Exists(), "WSv2 payload 应保留 stream 字段")
+ require.True(t, gjson.Get(requestJSON, "stream").Bool(), "OAuth Codex 规范化后应强制 stream=true")
+ require.Equal(t, openAIWSBetaV2Value, captureDialer.lastHeaders.Get("OpenAI-Beta"))
+ require.Equal(t, "sess-oauth-1", captureDialer.lastHeaders.Get("session_id"))
+ require.Equal(t, "conv-oauth-1", captureDialer.lastHeaders.Get("conversation_id"))
+}
+
+func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheKey(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+
+ captureConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_prompt_cache_key","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`),
+ },
+ }
+ captureDialer := &openAIWSCaptureDialer{conn: captureConn}
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(captureDialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+ account := &Account{
+ ID: 31,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "oauth-token-1",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.1","stream":true,"prompt_cache_key":"pcache_123","input":[{"type":"input_text","text":"hi"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "resp_prompt_cache_key", result.RequestID)
+
+ require.Equal(t, "pcache_123", captureDialer.lastHeaders.Get("session_id"))
+ require.Empty(t, captureDialer.lastHeaders.Get("conversation_id"))
+ require.NotNil(t, captureConn.lastWrite)
+ require.True(t, gjson.Get(requestToJSONString(captureConn.lastWrite), "stream").Exists())
+}
+
+func TestOpenAIGatewayService_Forward_WSv1_Unsupported(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsockets = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)),
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: upstream,
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ }
+
+ account := &Account{
+ ID: 39,
+ Name: "openai-ws-v1",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": "https://api.openai.com/v1/responses",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_v1","input":[{"type":"input_text","text":"hello"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.Error(t, err)
+ require.Nil(t, result)
+ require.Contains(t, err.Error(), "ws v1")
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.Contains(t, rec.Body.String(), "WSv1")
+ require.Nil(t, upstream.lastReq, "WSv1 不支持时不应触发 HTTP 上游请求")
+}
+
+func TestOpenAIGatewayService_Forward_WSv2_TurnStateAndMetadataReplayOnReconnect(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ var connIndex atomic.Int64
+ headersCh := make(chan http.Header, 4)
+ upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ idx := connIndex.Add(1)
+ headersCh <- cloneHeader(r.Header)
+
+ respHeader := http.Header{}
+ if idx == 1 {
+ respHeader.Set("x-codex-turn-state", "turn_state_first")
+ }
+ conn, err := upgrader.Upgrade(w, r, respHeader)
+ if err != nil {
+ t.Errorf("upgrade websocket failed: %v", err)
+ return
+ }
+ defer func() {
+ _ = conn.Close()
+ }()
+
+ var request map[string]any
+ if err := conn.ReadJSON(&request); err != nil {
+ t.Errorf("read ws request failed: %v", err)
+ return
+ }
+ responseID := "resp_turn_" + strconv.FormatInt(idx, 10)
+ if err := conn.WriteJSON(map[string]any{
+ "type": "response.completed",
+ "response": map[string]any{
+ "id": responseID,
+ "model": "gpt-5.1",
+ "usage": map[string]any{
+ "input_tokens": 2,
+ "output_tokens": 1,
+ },
+ },
+ }); err != nil {
+ t.Errorf("write response.completed failed: %v", err)
+ return
+ }
+ }))
+ defer wsServer.Close()
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 0
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ }
+
+ account := &Account{
+ ID: 49,
+ Name: "openai-turn-state",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": wsServer.URL,
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ reqBody := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
+ rec1 := httptest.NewRecorder()
+ c1, _ := gin.CreateTestContext(rec1)
+ c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c1.Request.Header.Set("session_id", "session_turn_state")
+ c1.Request.Header.Set("x-codex-turn-metadata", "turn_meta_1")
+ result1, err := svc.Forward(context.Background(), c1, account, reqBody)
+ require.NoError(t, err)
+ require.NotNil(t, result1)
+
+ sessionHash := svc.GenerateSessionHash(c1, reqBody)
+ store := svc.getOpenAIWSStateStore()
+ turnState, ok := store.GetSessionTurnState(0, sessionHash)
+ require.True(t, ok)
+ require.Equal(t, "turn_state_first", turnState)
+
+ // 主动淘汰连接,模拟下一次请求发生重连。
+ connID, hasConn := store.GetResponseConn(result1.RequestID)
+ require.True(t, hasConn)
+ svc.getOpenAIWSConnPool().evictConn(account.ID, connID)
+
+ rec2 := httptest.NewRecorder()
+ c2, _ := gin.CreateTestContext(rec2)
+ c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c2.Request.Header.Set("session_id", "session_turn_state")
+ c2.Request.Header.Set("x-codex-turn-metadata", "turn_meta_2")
+ result2, err := svc.Forward(context.Background(), c2, account, reqBody)
+ require.NoError(t, err)
+ require.NotNil(t, result2)
+
+ firstHandshakeHeaders := <-headersCh
+ secondHandshakeHeaders := <-headersCh
+ require.Equal(t, "turn_meta_1", firstHandshakeHeaders.Get("X-Codex-Turn-Metadata"))
+ require.Equal(t, "turn_meta_2", secondHandshakeHeaders.Get("X-Codex-Turn-Metadata"))
+ require.Equal(t, "turn_state_first", secondHandshakeHeaders.Get("X-Codex-Turn-State"))
+}
+
+func TestOpenAIGatewayService_Forward_WSv2_GeneratePrewarm(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("session_id", "session-prewarm")
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+
+ captureConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_prewarm_1","model":"gpt-5.1","usage":{"input_tokens":0,"output_tokens":0}}}`),
+ []byte(`{"type":"response.completed","response":{"id":"resp_main_1","model":"gpt-5.1","usage":{"input_tokens":4,"output_tokens":2}}}`),
+ },
+ }
+ captureDialer := &openAIWSCaptureDialer{conn: captureConn}
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(captureDialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 59,
+ Name: "openai-prewarm",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "resp_main_1", result.RequestID)
+
+ require.Len(t, captureConn.writes, 2, "开启 generate=false 预热后应发送两次 WS 请求")
+ firstWrite := requestToJSONString(captureConn.writes[0])
+ secondWrite := requestToJSONString(captureConn.writes[1])
+ require.True(t, gjson.Get(firstWrite, "generate").Exists())
+ require.False(t, gjson.Get(firstWrite, "generate").Bool())
+ require.False(t, gjson.Get(secondWrite, "generate").Exists())
+}
+
+func TestOpenAIGatewayService_PrewarmReadHonorsParentContext(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled = true
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ toolCorrector: NewCodexToolCorrector(),
+ }
+ account := &Account{
+ ID: 601,
+ Name: "openai-prewarm-timeout",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ }
+ conn := newOpenAIWSConn("prewarm_ctx_conn", account.ID, &openAIWSBlockingConn{
+ readDelay: 200 * time.Millisecond,
+ }, nil)
+ lease := &openAIWSConnLease{
+ accountID: account.ID,
+ conn: conn,
+ }
+ payload := map[string]any{
+ "type": "response.create",
+ "model": "gpt-5.1",
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond)
+ defer cancel()
+ start := time.Now()
+ err := svc.performOpenAIWSGeneratePrewarm(
+ ctx,
+ lease,
+ OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
+ payload,
+ "",
+ map[string]any{"model": "gpt-5.1"},
+ account,
+ nil,
+ 0,
+ )
+ elapsed := time.Since(start)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "prewarm_read_event")
+ require.Less(t, elapsed, 180*time.Millisecond, "预热读取应受父 context 取消控制,不应阻塞到 read_timeout")
+}
+
+func TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+
+ captureConn := &openAIWSCaptureConn{
+ events: [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_meta_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ []byte(`{"type":"response.completed","response":{"id":"resp_meta_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ }
+ captureDialer := &openAIWSCaptureDialer{conn: captureConn}
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(captureDialer)
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 69,
+ Name: "openai-turn-metadata",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
+
+ rec1 := httptest.NewRecorder()
+ c1, _ := gin.CreateTestContext(rec1)
+ c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c1.Request.Header.Set("session_id", "session-metadata-reuse")
+ c1.Request.Header.Set("x-codex-turn-metadata", "turn_meta_payload_1")
+ result1, err := svc.Forward(context.Background(), c1, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result1)
+ require.Equal(t, "resp_meta_1", result1.RequestID)
+
+ rec2 := httptest.NewRecorder()
+ c2, _ := gin.CreateTestContext(rec2)
+ c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c2.Request.Header.Set("session_id", "session-metadata-reuse")
+ c2.Request.Header.Set("x-codex-turn-metadata", "turn_meta_payload_2")
+ result2, err := svc.Forward(context.Background(), c2, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result2)
+ require.Equal(t, "resp_meta_2", result2.RequestID)
+
+ require.Equal(t, 1, captureDialer.DialCount(), "同一账号两轮请求应复用同一 WS 连接")
+ require.Len(t, captureConn.writes, 2)
+
+ firstWrite := requestToJSONString(captureConn.writes[0])
+ secondWrite := requestToJSONString(captureConn.writes[1])
+ require.Equal(t, "turn_meta_payload_1", gjson.Get(firstWrite, "client_metadata.x-codex-turn-metadata").String())
+ require.Equal(t, "turn_meta_payload_2", gjson.Get(secondWrite, "client_metadata.x-codex-turn-metadata").String())
+}
+
+func TestOpenAIGatewayService_Forward_WSv2StoreFalseSessionConnIsolation(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ var upgradeCount atomic.Int64
+ var sequence atomic.Int64
+ upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ upgradeCount.Add(1)
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ t.Errorf("upgrade websocket failed: %v", err)
+ return
+ }
+ defer func() {
+ _ = conn.Close()
+ }()
+
+ for {
+ var request map[string]any
+ if err := conn.ReadJSON(&request); err != nil {
+ return
+ }
+ responseID := "resp_store_false_" + strconv.FormatInt(sequence.Add(1), 10)
+ if err := conn.WriteJSON(map[string]any{
+ "type": "response.completed",
+ "response": map[string]any{
+ "id": responseID,
+ "model": "gpt-5.1",
+ "usage": map[string]any{
+ "input_tokens": 1,
+ "output_tokens": 1,
+ },
+ },
+ }); err != nil {
+ return
+ }
+ }
+ }))
+ defer wsServer.Close()
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4
+ cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ }
+
+ account := &Account{
+ ID: 79,
+ Name: "openai-store-false",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 2,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": wsServer.URL,
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
+
+ rec1 := httptest.NewRecorder()
+ c1, _ := gin.CreateTestContext(rec1)
+ c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c1.Request.Header.Set("session_id", "session_store_false_a")
+ result1, err := svc.Forward(context.Background(), c1, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result1)
+ require.Equal(t, int64(1), upgradeCount.Load())
+
+ rec2 := httptest.NewRecorder()
+ c2, _ := gin.CreateTestContext(rec2)
+ c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c2.Request.Header.Set("session_id", "session_store_false_a")
+ result2, err := svc.Forward(context.Background(), c2, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result2)
+ require.Equal(t, int64(1), upgradeCount.Load(), "同一 session(store=false) 应复用同一 WS 连接")
+
+ rec3 := httptest.NewRecorder()
+ c3, _ := gin.CreateTestContext(rec3)
+ c3.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c3.Request.Header.Set("session_id", "session_store_false_b")
+ result3, err := svc.Forward(context.Background(), c3, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result3)
+ require.Equal(t, int64(2), upgradeCount.Load(), "不同 session(store=false) 应隔离连接,避免续链状态互相覆盖")
+}
+
+func TestOpenAIGatewayService_Forward_WSv2StoreFalseDisableForceNewConnAllowsReuse(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ var upgradeCount atomic.Int64
+ var sequence atomic.Int64
+ upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ upgradeCount.Add(1)
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ t.Errorf("upgrade websocket failed: %v", err)
+ return
+ }
+ defer func() {
+ _ = conn.Close()
+ }()
+
+ for {
+ var request map[string]any
+ if err := conn.ReadJSON(&request); err != nil {
+ return
+ }
+ responseID := "resp_store_false_reuse_" + strconv.FormatInt(sequence.Add(1), 10)
+ if err := conn.WriteJSON(map[string]any{
+ "type": "response.completed",
+ "response": map[string]any{
+ "id": responseID,
+ "model": "gpt-5.1",
+ "usage": map[string]any{
+ "input_tokens": 1,
+ "output_tokens": 1,
+ },
+ },
+ }); err != nil {
+ return
+ }
+ }
+ }))
+ defer wsServer.Close()
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = false
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ }
+
+ account := &Account{
+ ID: 80,
+ Name: "openai-store-false-reuse",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 2,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": wsServer.URL,
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
+
+ rec1 := httptest.NewRecorder()
+ c1, _ := gin.CreateTestContext(rec1)
+ c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c1.Request.Header.Set("session_id", "session_store_false_reuse_a")
+ result1, err := svc.Forward(context.Background(), c1, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result1)
+ require.Equal(t, int64(1), upgradeCount.Load())
+
+ rec2 := httptest.NewRecorder()
+ c2, _ := gin.CreateTestContext(rec2)
+ c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c2.Request.Header.Set("session_id", "session_store_false_reuse_b")
+ result2, err := svc.Forward(context.Background(), c2, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result2)
+ require.Equal(t, int64(1), upgradeCount.Load(), "关闭强制新连后,不同 session(store=false) 可复用连接")
+}
+
+func TestOpenAIGatewayService_Forward_WSv2ReadTimeoutAppliesPerRead(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 1
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ captureConn := &openAIWSCaptureConn{
+ readDelays: []time.Duration{
+ 700 * time.Millisecond,
+ 700 * time.Millisecond,
+ },
+ events: [][]byte{
+ []byte(`{"type":"response.created","response":{"id":"resp_timeout_ok","model":"gpt-5.1"}}`),
+ []byte(`{"type":"response.completed","response":{"id":"resp_timeout_ok","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`),
+ },
+ }
+ captureDialer := &openAIWSCaptureDialer{conn: captureConn}
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(captureDialer)
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_fallback","usage":{"input_tokens":1,"output_tokens":1}}`)),
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: upstream,
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ openaiWSPool: pool,
+ }
+
+ account := &Account{
+ ID: 81,
+ Name: "openai-read-timeout",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "resp_timeout_ok", result.RequestID)
+ require.Nil(t, upstream.lastReq, "每次 Read 都应独立应用超时;总时长超过 read_timeout 不应误回退 HTTP")
+}
+
+type openAIWSCaptureDialer struct {
+ mu sync.Mutex
+ conn *openAIWSCaptureConn
+ lastHeaders http.Header
+ handshake http.Header
+ dialCount int
+}
+
+func (d *openAIWSCaptureDialer) Dial(
+ ctx context.Context,
+ wsURL string,
+ headers http.Header,
+ proxyURL string,
+) (openAIWSClientConn, int, http.Header, error) {
+ _ = ctx
+ _ = wsURL
+ _ = proxyURL
+ d.mu.Lock()
+ d.lastHeaders = cloneHeader(headers)
+ d.dialCount++
+ respHeaders := cloneHeader(d.handshake)
+ d.mu.Unlock()
+ return d.conn, 0, respHeaders, nil
+}
+
+func (d *openAIWSCaptureDialer) DialCount() int {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return d.dialCount
+}
+
+type openAIWSCaptureConn struct {
+ mu sync.Mutex
+ readDelays []time.Duration
+ events [][]byte
+ lastWrite map[string]any
+ writes []map[string]any
+ closed bool
+}
+
+func (c *openAIWSCaptureConn) WriteJSON(ctx context.Context, value any) error {
+ _ = ctx
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closed {
+ return errOpenAIWSConnClosed
+ }
+ switch payload := value.(type) {
+ case map[string]any:
+ c.lastWrite = cloneMapStringAny(payload)
+ c.writes = append(c.writes, cloneMapStringAny(payload))
+ case json.RawMessage:
+ var parsed map[string]any
+ if err := json.Unmarshal(payload, &parsed); err == nil {
+ c.lastWrite = cloneMapStringAny(parsed)
+ c.writes = append(c.writes, cloneMapStringAny(parsed))
+ }
+ case []byte:
+ var parsed map[string]any
+ if err := json.Unmarshal(payload, &parsed); err == nil {
+ c.lastWrite = cloneMapStringAny(parsed)
+ c.writes = append(c.writes, cloneMapStringAny(parsed))
+ }
+ }
+ return nil
+}
+
+func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ c.mu.Lock()
+ if c.closed {
+ c.mu.Unlock()
+ return nil, errOpenAIWSConnClosed
+ }
+ if len(c.events) == 0 {
+ c.mu.Unlock()
+ return nil, io.EOF
+ }
+ delay := time.Duration(0)
+ if len(c.readDelays) > 0 {
+ delay = c.readDelays[0]
+ c.readDelays = c.readDelays[1:]
+ }
+ event := c.events[0]
+ c.events = c.events[1:]
+ c.mu.Unlock()
+ if delay > 0 {
+ timer := time.NewTimer(delay)
+ defer timer.Stop()
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-timer.C:
+ }
+ }
+ return event, nil
+}
+
+func (c *openAIWSCaptureConn) Ping(ctx context.Context) error {
+ _ = ctx
+ return nil
+}
+
+func (c *openAIWSCaptureConn) Close() error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.closed = true
+ return nil
+}
+
+func cloneMapStringAny(src map[string]any) map[string]any {
+ if src == nil {
+ return nil
+ }
+ dst := make(map[string]any, len(src))
+ for k, v := range src {
+ dst[k] = v
+ }
+ return dst
+}
diff --git a/backend/internal/service/openai_ws_pool.go b/backend/internal/service/openai_ws_pool.go
new file mode 100644
index 00000000..db6a96a7
--- /dev/null
+++ b/backend/internal/service/openai_ws_pool.go
@@ -0,0 +1,1706 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "math"
+ "net/http"
+ "sort"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "golang.org/x/sync/errgroup"
+)
+
+const (
+ openAIWSConnMaxAge = 60 * time.Minute
+ openAIWSConnHealthCheckIdle = 90 * time.Second
+ openAIWSConnHealthCheckTO = 2 * time.Second
+ openAIWSConnPrewarmExtraDelay = 2 * time.Second
+ openAIWSAcquireCleanupInterval = 3 * time.Second
+ openAIWSBackgroundPingInterval = 30 * time.Second
+ openAIWSBackgroundSweepTicker = 30 * time.Second
+
+ openAIWSPrewarmFailureWindow = 30 * time.Second
+ openAIWSPrewarmFailureSuppress = 2
+)
+
+var (
+ errOpenAIWSConnClosed = errors.New("openai ws connection closed")
+ errOpenAIWSConnQueueFull = errors.New("openai ws connection queue full")
+ errOpenAIWSPreferredConnUnavailable = errors.New("openai ws preferred connection unavailable")
+)
+
+type openAIWSDialError struct {
+ StatusCode int
+ ResponseHeaders http.Header
+ Err error
+}
+
+func (e *openAIWSDialError) Error() string {
+ if e == nil {
+ return ""
+ }
+ if e.StatusCode > 0 {
+ return fmt.Sprintf("openai ws dial failed: status=%d err=%v", e.StatusCode, e.Err)
+ }
+ return fmt.Sprintf("openai ws dial failed: %v", e.Err)
+}
+
+func (e *openAIWSDialError) Unwrap() error {
+ if e == nil {
+ return nil
+ }
+ return e.Err
+}
+
+type openAIWSAcquireRequest struct {
+ Account *Account
+ WSURL string
+ Headers http.Header
+ ProxyURL string
+ PreferredConnID string
+ // ForceNewConn: 强制本次获取新连接(避免复用导致连接内续链状态互相污染)。
+ ForceNewConn bool
+ // ForcePreferredConn: 强制本次只使用 PreferredConnID,禁止漂移到其它连接。
+ ForcePreferredConn bool
+}
+
+type openAIWSConnLease struct {
+ pool *openAIWSConnPool
+ accountID int64
+ conn *openAIWSConn
+ queueWait time.Duration
+ connPick time.Duration
+ reused bool
+ released atomic.Bool
+}
+
+func (l *openAIWSConnLease) activeConn() (*openAIWSConn, error) {
+ if l == nil || l.conn == nil {
+ return nil, errOpenAIWSConnClosed
+ }
+ if l.released.Load() {
+ return nil, errOpenAIWSConnClosed
+ }
+ return l.conn, nil
+}
+
+func (l *openAIWSConnLease) ConnID() string {
+ if l == nil || l.conn == nil {
+ return ""
+ }
+ return l.conn.id
+}
+
+func (l *openAIWSConnLease) QueueWaitDuration() time.Duration {
+ if l == nil {
+ return 0
+ }
+ return l.queueWait
+}
+
+func (l *openAIWSConnLease) ConnPickDuration() time.Duration {
+ if l == nil {
+ return 0
+ }
+ return l.connPick
+}
+
+func (l *openAIWSConnLease) Reused() bool {
+ if l == nil {
+ return false
+ }
+ return l.reused
+}
+
+func (l *openAIWSConnLease) HandshakeHeader(name string) string {
+ if l == nil || l.conn == nil {
+ return ""
+ }
+ return l.conn.handshakeHeader(name)
+}
+
+func (l *openAIWSConnLease) IsPrewarmed() bool {
+ if l == nil || l.conn == nil {
+ return false
+ }
+ return l.conn.isPrewarmed()
+}
+
+func (l *openAIWSConnLease) MarkPrewarmed() {
+ if l == nil || l.conn == nil {
+ return
+ }
+ l.conn.markPrewarmed()
+}
+
+func (l *openAIWSConnLease) WriteJSON(value any, timeout time.Duration) error {
+ conn, err := l.activeConn()
+ if err != nil {
+ return err
+ }
+ return conn.writeJSONWithTimeout(context.Background(), value, timeout)
+}
+
+func (l *openAIWSConnLease) WriteJSONWithContextTimeout(ctx context.Context, value any, timeout time.Duration) error {
+ conn, err := l.activeConn()
+ if err != nil {
+ return err
+ }
+ return conn.writeJSONWithTimeout(ctx, value, timeout)
+}
+
+func (l *openAIWSConnLease) WriteJSONContext(ctx context.Context, value any) error {
+ conn, err := l.activeConn()
+ if err != nil {
+ return err
+ }
+ return conn.writeJSON(value, ctx)
+}
+
+func (l *openAIWSConnLease) ReadMessage(timeout time.Duration) ([]byte, error) {
+ conn, err := l.activeConn()
+ if err != nil {
+ return nil, err
+ }
+ return conn.readMessageWithTimeout(timeout)
+}
+
+func (l *openAIWSConnLease) ReadMessageContext(ctx context.Context) ([]byte, error) {
+ conn, err := l.activeConn()
+ if err != nil {
+ return nil, err
+ }
+ return conn.readMessage(ctx)
+}
+
+func (l *openAIWSConnLease) ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error) {
+ conn, err := l.activeConn()
+ if err != nil {
+ return nil, err
+ }
+ return conn.readMessageWithContextTimeout(ctx, timeout)
+}
+
+func (l *openAIWSConnLease) PingWithTimeout(timeout time.Duration) error {
+ conn, err := l.activeConn()
+ if err != nil {
+ return err
+ }
+ return conn.pingWithTimeout(timeout)
+}
+
+func (l *openAIWSConnLease) MarkBroken() {
+ if l == nil || l.pool == nil || l.conn == nil || l.released.Load() {
+ return
+ }
+ l.pool.evictConn(l.accountID, l.conn.id)
+}
+
+func (l *openAIWSConnLease) Release() {
+ if l == nil || l.conn == nil {
+ return
+ }
+ if !l.released.CompareAndSwap(false, true) {
+ return
+ }
+ l.conn.release()
+}
+
+type openAIWSConn struct {
+ id string
+ ws openAIWSClientConn
+
+ handshakeHeaders http.Header
+
+ leaseCh chan struct{}
+ closedCh chan struct{}
+ closeOnce sync.Once
+
+ readMu sync.Mutex
+ writeMu sync.Mutex
+
+ waiters atomic.Int32
+ createdAtNano atomic.Int64
+ lastUsedNano atomic.Int64
+ prewarmed atomic.Bool
+}
+
+func newOpenAIWSConn(id string, _ int64, ws openAIWSClientConn, handshakeHeaders http.Header) *openAIWSConn {
+ now := time.Now()
+ conn := &openAIWSConn{
+ id: id,
+ ws: ws,
+ handshakeHeaders: cloneHeader(handshakeHeaders),
+ leaseCh: make(chan struct{}, 1),
+ closedCh: make(chan struct{}),
+ }
+ conn.leaseCh <- struct{}{}
+ conn.createdAtNano.Store(now.UnixNano())
+ conn.lastUsedNano.Store(now.UnixNano())
+ return conn
+}
+
+func (c *openAIWSConn) tryAcquire() bool {
+ if c == nil {
+ return false
+ }
+ select {
+ case <-c.closedCh:
+ return false
+ default:
+ }
+ select {
+ case <-c.leaseCh:
+ select {
+ case <-c.closedCh:
+ c.release()
+ return false
+ default:
+ }
+ return true
+ default:
+ return false
+ }
+}
+
+func (c *openAIWSConn) acquire(ctx context.Context) error {
+ if c == nil {
+ return errOpenAIWSConnClosed
+ }
+ for {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-c.closedCh:
+ return errOpenAIWSConnClosed
+ case <-c.leaseCh:
+ select {
+ case <-c.closedCh:
+ c.release()
+ return errOpenAIWSConnClosed
+ default:
+ }
+ return nil
+ }
+ }
+}
+
+func (c *openAIWSConn) release() {
+ if c == nil {
+ return
+ }
+ select {
+ case c.leaseCh <- struct{}{}:
+ default:
+ }
+ c.touch()
+}
+
+func (c *openAIWSConn) close() {
+ if c == nil {
+ return
+ }
+ c.closeOnce.Do(func() {
+ close(c.closedCh)
+ if c.ws != nil {
+ _ = c.ws.Close()
+ }
+ select {
+ case c.leaseCh <- struct{}{}:
+ default:
+ }
+ })
+}
+
+func (c *openAIWSConn) writeJSONWithTimeout(parent context.Context, value any, timeout time.Duration) error {
+ if c == nil {
+ return errOpenAIWSConnClosed
+ }
+ select {
+ case <-c.closedCh:
+ return errOpenAIWSConnClosed
+ default:
+ }
+
+ writeCtx := parent
+ if writeCtx == nil {
+ writeCtx = context.Background()
+ }
+ if timeout <= 0 {
+ return c.writeJSON(value, writeCtx)
+ }
+ var cancel context.CancelFunc
+ writeCtx, cancel = context.WithTimeout(writeCtx, timeout)
+ defer cancel()
+ return c.writeJSON(value, writeCtx)
+}
+
+func (c *openAIWSConn) writeJSON(value any, writeCtx context.Context) error {
+ c.writeMu.Lock()
+ defer c.writeMu.Unlock()
+ if c.ws == nil {
+ return errOpenAIWSConnClosed
+ }
+ if writeCtx == nil {
+ writeCtx = context.Background()
+ }
+ if err := c.ws.WriteJSON(writeCtx, value); err != nil {
+ return err
+ }
+ c.touch()
+ return nil
+}
+
+func (c *openAIWSConn) readMessageWithTimeout(timeout time.Duration) ([]byte, error) {
+ return c.readMessageWithContextTimeout(context.Background(), timeout)
+}
+
+func (c *openAIWSConn) readMessageWithContextTimeout(parent context.Context, timeout time.Duration) ([]byte, error) {
+ if c == nil {
+ return nil, errOpenAIWSConnClosed
+ }
+ select {
+ case <-c.closedCh:
+ return nil, errOpenAIWSConnClosed
+ default:
+ }
+
+ if parent == nil {
+ parent = context.Background()
+ }
+ if timeout <= 0 {
+ return c.readMessage(parent)
+ }
+ readCtx, cancel := context.WithTimeout(parent, timeout)
+ defer cancel()
+ return c.readMessage(readCtx)
+}
+
+func (c *openAIWSConn) readMessage(readCtx context.Context) ([]byte, error) {
+ c.readMu.Lock()
+ defer c.readMu.Unlock()
+ if c.ws == nil {
+ return nil, errOpenAIWSConnClosed
+ }
+ if readCtx == nil {
+ readCtx = context.Background()
+ }
+ payload, err := c.ws.ReadMessage(readCtx)
+ if err != nil {
+ return nil, err
+ }
+ c.touch()
+ return payload, nil
+}
+
+func (c *openAIWSConn) pingWithTimeout(timeout time.Duration) error {
+ if c == nil {
+ return errOpenAIWSConnClosed
+ }
+ select {
+ case <-c.closedCh:
+ return errOpenAIWSConnClosed
+ default:
+ }
+
+ c.writeMu.Lock()
+ defer c.writeMu.Unlock()
+ if c.ws == nil {
+ return errOpenAIWSConnClosed
+ }
+ if timeout <= 0 {
+ timeout = openAIWSConnHealthCheckTO
+ }
+ pingCtx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ if err := c.ws.Ping(pingCtx); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (c *openAIWSConn) touch() {
+ if c == nil {
+ return
+ }
+ c.lastUsedNano.Store(time.Now().UnixNano())
+}
+
+func (c *openAIWSConn) createdAt() time.Time {
+ if c == nil {
+ return time.Time{}
+ }
+ nano := c.createdAtNano.Load()
+ if nano <= 0 {
+ return time.Time{}
+ }
+ return time.Unix(0, nano)
+}
+
+func (c *openAIWSConn) lastUsedAt() time.Time {
+ if c == nil {
+ return time.Time{}
+ }
+ nano := c.lastUsedNano.Load()
+ if nano <= 0 {
+ return time.Time{}
+ }
+ return time.Unix(0, nano)
+}
+
+func (c *openAIWSConn) idleDuration(now time.Time) time.Duration {
+ if c == nil {
+ return 0
+ }
+ last := c.lastUsedAt()
+ if last.IsZero() {
+ return 0
+ }
+ return now.Sub(last)
+}
+
+func (c *openAIWSConn) age(now time.Time) time.Duration {
+ if c == nil {
+ return 0
+ }
+ created := c.createdAt()
+ if created.IsZero() {
+ return 0
+ }
+ return now.Sub(created)
+}
+
+func (c *openAIWSConn) isLeased() bool {
+ if c == nil {
+ return false
+ }
+ return len(c.leaseCh) == 0
+}
+
+func (c *openAIWSConn) handshakeHeader(name string) string {
+ if c == nil || c.handshakeHeaders == nil {
+ return ""
+ }
+ return strings.TrimSpace(c.handshakeHeaders.Get(strings.TrimSpace(name)))
+}
+
+func (c *openAIWSConn) isPrewarmed() bool {
+ if c == nil {
+ return false
+ }
+ return c.prewarmed.Load()
+}
+
+func (c *openAIWSConn) markPrewarmed() {
+ if c == nil {
+ return
+ }
+ c.prewarmed.Store(true)
+}
+
+type openAIWSAccountPool struct {
+ mu sync.Mutex
+ conns map[string]*openAIWSConn
+ pinnedConns map[string]int
+ creating int
+ lastCleanupAt time.Time
+ lastAcquire *openAIWSAcquireRequest
+ prewarmActive bool
+ prewarmUntil time.Time
+ prewarmFails int
+ prewarmFailAt time.Time
+}
+
+type OpenAIWSPoolMetricsSnapshot struct {
+ AcquireTotal int64
+ AcquireReuseTotal int64
+ AcquireCreateTotal int64
+ AcquireQueueWaitTotal int64
+ AcquireQueueWaitMsTotal int64
+ ConnPickTotal int64
+ ConnPickMsTotal int64
+ ScaleUpTotal int64
+ ScaleDownTotal int64
+}
+
+type openAIWSPoolMetrics struct {
+ acquireTotal atomic.Int64
+ acquireReuseTotal atomic.Int64
+ acquireCreateTotal atomic.Int64
+ acquireQueueWaitTotal atomic.Int64
+ acquireQueueWaitMs atomic.Int64
+ connPickTotal atomic.Int64
+ connPickMs atomic.Int64
+ scaleUpTotal atomic.Int64
+ scaleDownTotal atomic.Int64
+}
+
+type openAIWSConnPool struct {
+ cfg *config.Config
+ // 通过接口解耦底层 WS 客户端实现,默认使用 coder/websocket。
+ clientDialer openAIWSClientDialer
+
+ accounts sync.Map // key: int64(accountID), value: *openAIWSAccountPool
+ seq atomic.Uint64
+
+ metrics openAIWSPoolMetrics
+
+ workerStopCh chan struct{}
+ workerWg sync.WaitGroup
+ closeOnce sync.Once
+}
+
+func newOpenAIWSConnPool(cfg *config.Config) *openAIWSConnPool {
+ pool := &openAIWSConnPool{
+ cfg: cfg,
+ clientDialer: newDefaultOpenAIWSClientDialer(),
+ workerStopCh: make(chan struct{}),
+ }
+ pool.startBackgroundWorkers()
+ return pool
+}
+
+func (p *openAIWSConnPool) SnapshotMetrics() OpenAIWSPoolMetricsSnapshot {
+ if p == nil {
+ return OpenAIWSPoolMetricsSnapshot{}
+ }
+ return OpenAIWSPoolMetricsSnapshot{
+ AcquireTotal: p.metrics.acquireTotal.Load(),
+ AcquireReuseTotal: p.metrics.acquireReuseTotal.Load(),
+ AcquireCreateTotal: p.metrics.acquireCreateTotal.Load(),
+ AcquireQueueWaitTotal: p.metrics.acquireQueueWaitTotal.Load(),
+ AcquireQueueWaitMsTotal: p.metrics.acquireQueueWaitMs.Load(),
+ ConnPickTotal: p.metrics.connPickTotal.Load(),
+ ConnPickMsTotal: p.metrics.connPickMs.Load(),
+ ScaleUpTotal: p.metrics.scaleUpTotal.Load(),
+ ScaleDownTotal: p.metrics.scaleDownTotal.Load(),
+ }
+}
+
+func (p *openAIWSConnPool) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot {
+ if p == nil {
+ return OpenAIWSTransportMetricsSnapshot{}
+ }
+ if dialer, ok := p.clientDialer.(openAIWSTransportMetricsDialer); ok {
+ return dialer.SnapshotTransportMetrics()
+ }
+ return OpenAIWSTransportMetricsSnapshot{}
+}
+
+func (p *openAIWSConnPool) setClientDialerForTest(dialer openAIWSClientDialer) {
+ if p == nil || dialer == nil {
+ return
+ }
+ p.clientDialer = dialer
+}
+
+// Close 停止后台 worker 并关闭所有空闲连接,应在优雅关闭时调用。
+func (p *openAIWSConnPool) Close() {
+ if p == nil {
+ return
+ }
+ p.closeOnce.Do(func() {
+ if p.workerStopCh != nil {
+ close(p.workerStopCh)
+ }
+ p.workerWg.Wait()
+ // 遍历所有账户池,关闭全部空闲连接。
+ p.accounts.Range(func(key, value any) bool {
+ ap, ok := value.(*openAIWSAccountPool)
+ if !ok || ap == nil {
+ return true
+ }
+ ap.mu.Lock()
+ for _, conn := range ap.conns {
+ if conn != nil && !conn.isLeased() {
+ conn.close()
+ }
+ }
+ ap.mu.Unlock()
+ return true
+ })
+ })
+}
+
+func (p *openAIWSConnPool) startBackgroundWorkers() {
+ if p == nil || p.workerStopCh == nil {
+ return
+ }
+ p.workerWg.Add(2)
+ go func() {
+ defer p.workerWg.Done()
+ p.runBackgroundPingWorker()
+ }()
+ go func() {
+ defer p.workerWg.Done()
+ p.runBackgroundCleanupWorker()
+ }()
+}
+
+type openAIWSIdlePingCandidate struct {
+ accountID int64
+ conn *openAIWSConn
+}
+
+func (p *openAIWSConnPool) runBackgroundPingWorker() {
+ if p == nil {
+ return
+ }
+ ticker := time.NewTicker(openAIWSBackgroundPingInterval)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-ticker.C:
+ p.runBackgroundPingSweep()
+ case <-p.workerStopCh:
+ return
+ }
+ }
+}
+
+func (p *openAIWSConnPool) runBackgroundPingSweep() {
+ if p == nil {
+ return
+ }
+ candidates := p.snapshotIdleConnsForPing()
+ var g errgroup.Group
+ g.SetLimit(10)
+ for _, item := range candidates {
+ item := item
+ if item.conn == nil || item.conn.isLeased() || item.conn.waiters.Load() > 0 {
+ continue
+ }
+ g.Go(func() error {
+ if err := item.conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
+ p.evictConn(item.accountID, item.conn.id)
+ }
+ return nil
+ })
+ }
+ _ = g.Wait()
+}
+
+func (p *openAIWSConnPool) snapshotIdleConnsForPing() []openAIWSIdlePingCandidate {
+ if p == nil {
+ return nil
+ }
+ candidates := make([]openAIWSIdlePingCandidate, 0)
+ p.accounts.Range(func(key, value any) bool {
+ accountID, ok := key.(int64)
+ if !ok || accountID <= 0 {
+ return true
+ }
+ ap, ok := value.(*openAIWSAccountPool)
+ if !ok || ap == nil {
+ return true
+ }
+ ap.mu.Lock()
+ for _, conn := range ap.conns {
+ if conn == nil || conn.isLeased() || conn.waiters.Load() > 0 {
+ continue
+ }
+ candidates = append(candidates, openAIWSIdlePingCandidate{
+ accountID: accountID,
+ conn: conn,
+ })
+ }
+ ap.mu.Unlock()
+ return true
+ })
+ return candidates
+}
+
+func (p *openAIWSConnPool) runBackgroundCleanupWorker() {
+ if p == nil {
+ return
+ }
+ ticker := time.NewTicker(openAIWSBackgroundSweepTicker)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-ticker.C:
+ p.runBackgroundCleanupSweep(time.Now())
+ case <-p.workerStopCh:
+ return
+ }
+ }
+}
+
+func (p *openAIWSConnPool) runBackgroundCleanupSweep(now time.Time) {
+ if p == nil {
+ return
+ }
+ type cleanupResult struct {
+ evicted []*openAIWSConn
+ }
+ results := make([]cleanupResult, 0)
+ p.accounts.Range(func(_ any, value any) bool {
+ ap, ok := value.(*openAIWSAccountPool)
+ if !ok || ap == nil {
+ return true
+ }
+ maxConns := p.maxConnsHardCap()
+ ap.mu.Lock()
+ if ap.lastAcquire != nil && ap.lastAcquire.Account != nil {
+ maxConns = p.effectiveMaxConnsByAccount(ap.lastAcquire.Account)
+ }
+ evicted := p.cleanupAccountLocked(ap, now, maxConns)
+ ap.lastCleanupAt = now
+ ap.mu.Unlock()
+ if len(evicted) > 0 {
+ results = append(results, cleanupResult{evicted: evicted})
+ }
+ return true
+ })
+ for _, result := range results {
+ closeOpenAIWSConns(result.evicted)
+ }
+}
+
+func (p *openAIWSConnPool) Acquire(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConnLease, error) {
+ if p != nil {
+ p.metrics.acquireTotal.Add(1)
+ }
+ return p.acquire(ctx, cloneOpenAIWSAcquireRequest(req), 0)
+}
+
+func (p *openAIWSConnPool) acquire(ctx context.Context, req openAIWSAcquireRequest, retry int) (*openAIWSConnLease, error) {
+ if p == nil || req.Account == nil || req.Account.ID <= 0 {
+ return nil, errors.New("invalid ws acquire request")
+ }
+ if stringsTrim(req.WSURL) == "" {
+ return nil, errors.New("ws url is empty")
+ }
+
+ accountID := req.Account.ID
+ effectiveMaxConns := p.effectiveMaxConnsByAccount(req.Account)
+ if effectiveMaxConns <= 0 {
+ return nil, errOpenAIWSConnQueueFull
+ }
+ var evicted []*openAIWSConn
+ ap := p.getOrCreateAccountPool(accountID)
+ ap.mu.Lock()
+ ap.lastAcquire = cloneOpenAIWSAcquireRequestPtr(&req)
+ now := time.Now()
+ if ap.lastCleanupAt.IsZero() || now.Sub(ap.lastCleanupAt) >= openAIWSAcquireCleanupInterval {
+ evicted = p.cleanupAccountLocked(ap, now, effectiveMaxConns)
+ ap.lastCleanupAt = now
+ }
+ pickStartedAt := time.Now()
+ allowReuse := !req.ForceNewConn
+ preferredConnID := stringsTrim(req.PreferredConnID)
+ forcePreferredConn := allowReuse && req.ForcePreferredConn
+
+ if allowReuse {
+ if forcePreferredConn {
+ if preferredConnID == "" {
+ p.recordConnPickDuration(time.Since(pickStartedAt))
+ ap.mu.Unlock()
+ closeOpenAIWSConns(evicted)
+ return nil, errOpenAIWSPreferredConnUnavailable
+ }
+ preferredConn, ok := ap.conns[preferredConnID]
+ if !ok || preferredConn == nil {
+ p.recordConnPickDuration(time.Since(pickStartedAt))
+ ap.mu.Unlock()
+ closeOpenAIWSConns(evicted)
+ return nil, errOpenAIWSPreferredConnUnavailable
+ }
+ if preferredConn.tryAcquire() {
+ connPick := time.Since(pickStartedAt)
+ p.recordConnPickDuration(connPick)
+ ap.mu.Unlock()
+ closeOpenAIWSConns(evicted)
+ if p.shouldHealthCheckConn(preferredConn) {
+ if err := preferredConn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
+ preferredConn.close()
+ p.evictConn(accountID, preferredConn.id)
+ if retry < 1 {
+ return p.acquire(ctx, req, retry+1)
+ }
+ return nil, err
+ }
+ }
+ lease := &openAIWSConnLease{
+ pool: p,
+ accountID: accountID,
+ conn: preferredConn,
+ connPick: connPick,
+ reused: true,
+ }
+ p.metrics.acquireReuseTotal.Add(1)
+ p.ensureTargetIdleAsync(accountID)
+ return lease, nil
+ }
+
+ connPick := time.Since(pickStartedAt)
+ p.recordConnPickDuration(connPick)
+ if int(preferredConn.waiters.Load()) >= p.queueLimitPerConn() {
+ ap.mu.Unlock()
+ closeOpenAIWSConns(evicted)
+ return nil, errOpenAIWSConnQueueFull
+ }
+ preferredConn.waiters.Add(1)
+ ap.mu.Unlock()
+ closeOpenAIWSConns(evicted)
+ defer preferredConn.waiters.Add(-1)
+ waitStart := time.Now()
+ p.metrics.acquireQueueWaitTotal.Add(1)
+
+ if err := preferredConn.acquire(ctx); err != nil {
+ if errors.Is(err, errOpenAIWSConnClosed) && retry < 1 {
+ return p.acquire(ctx, req, retry+1)
+ }
+ return nil, err
+ }
+ if p.shouldHealthCheckConn(preferredConn) {
+ if err := preferredConn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
+ preferredConn.release()
+ preferredConn.close()
+ p.evictConn(accountID, preferredConn.id)
+ if retry < 1 {
+ return p.acquire(ctx, req, retry+1)
+ }
+ return nil, err
+ }
+ }
+
+ queueWait := time.Since(waitStart)
+ p.metrics.acquireQueueWaitMs.Add(queueWait.Milliseconds())
+ lease := &openAIWSConnLease{
+ pool: p,
+ accountID: accountID,
+ conn: preferredConn,
+ queueWait: queueWait,
+ connPick: connPick,
+ reused: true,
+ }
+ p.metrics.acquireReuseTotal.Add(1)
+ p.ensureTargetIdleAsync(accountID)
+ return lease, nil
+ }
+
+ if preferredConnID != "" {
+ if conn, ok := ap.conns[preferredConnID]; ok && conn.tryAcquire() {
+ connPick := time.Since(pickStartedAt)
+ p.recordConnPickDuration(connPick)
+ ap.mu.Unlock()
+ closeOpenAIWSConns(evicted)
+ if p.shouldHealthCheckConn(conn) {
+ if err := conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
+ conn.close()
+ p.evictConn(accountID, conn.id)
+ if retry < 1 {
+ return p.acquire(ctx, req, retry+1)
+ }
+ return nil, err
+ }
+ }
+ lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick, reused: true}
+ p.metrics.acquireReuseTotal.Add(1)
+ p.ensureTargetIdleAsync(accountID)
+ return lease, nil
+ }
+ }
+
+ best := p.pickLeastBusyConnLocked(ap, "")
+ if best != nil && best.tryAcquire() {
+ connPick := time.Since(pickStartedAt)
+ p.recordConnPickDuration(connPick)
+ ap.mu.Unlock()
+ closeOpenAIWSConns(evicted)
+ if p.shouldHealthCheckConn(best) {
+ if err := best.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
+ best.close()
+ p.evictConn(accountID, best.id)
+ if retry < 1 {
+ return p.acquire(ctx, req, retry+1)
+ }
+ return nil, err
+ }
+ }
+ lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: best, connPick: connPick, reused: true}
+ p.metrics.acquireReuseTotal.Add(1)
+ p.ensureTargetIdleAsync(accountID)
+ return lease, nil
+ }
+ for _, conn := range ap.conns {
+ if conn == nil || conn == best {
+ continue
+ }
+ if conn.tryAcquire() {
+ connPick := time.Since(pickStartedAt)
+ p.recordConnPickDuration(connPick)
+ ap.mu.Unlock()
+ closeOpenAIWSConns(evicted)
+ if p.shouldHealthCheckConn(conn) {
+ if err := conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
+ conn.close()
+ p.evictConn(accountID, conn.id)
+ if retry < 1 {
+ return p.acquire(ctx, req, retry+1)
+ }
+ return nil, err
+ }
+ }
+ lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick, reused: true}
+ p.metrics.acquireReuseTotal.Add(1)
+ p.ensureTargetIdleAsync(accountID)
+ return lease, nil
+ }
+ }
+ }
+
+ if req.ForceNewConn && len(ap.conns)+ap.creating >= effectiveMaxConns {
+ if idle := p.pickOldestIdleConnLocked(ap); idle != nil {
+ delete(ap.conns, idle.id)
+ evicted = append(evicted, idle)
+ p.metrics.scaleDownTotal.Add(1)
+ }
+ }
+
+ if len(ap.conns)+ap.creating < effectiveMaxConns {
+ connPick := time.Since(pickStartedAt)
+ p.recordConnPickDuration(connPick)
+ ap.creating++
+ ap.mu.Unlock()
+ closeOpenAIWSConns(evicted)
+
+ conn, dialErr := p.dialConn(ctx, req)
+
+ ap = p.getOrCreateAccountPool(accountID)
+ ap.mu.Lock()
+ ap.creating--
+ if dialErr != nil {
+ ap.prewarmFails++
+ ap.prewarmFailAt = time.Now()
+ ap.mu.Unlock()
+ return nil, dialErr
+ }
+ ap.conns[conn.id] = conn
+ ap.prewarmFails = 0
+ ap.prewarmFailAt = time.Time{}
+ ap.mu.Unlock()
+ p.metrics.acquireCreateTotal.Add(1)
+
+ if !conn.tryAcquire() {
+ if err := conn.acquire(ctx); err != nil {
+ conn.close()
+ p.evictConn(accountID, conn.id)
+ return nil, err
+ }
+ }
+ lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick}
+ p.ensureTargetIdleAsync(accountID)
+ return lease, nil
+ }
+
+ if req.ForceNewConn {
+ p.recordConnPickDuration(time.Since(pickStartedAt))
+ ap.mu.Unlock()
+ closeOpenAIWSConns(evicted)
+ return nil, errOpenAIWSConnQueueFull
+ }
+
+ target := p.pickLeastBusyConnLocked(ap, req.PreferredConnID)
+ connPick := time.Since(pickStartedAt)
+ p.recordConnPickDuration(connPick)
+ if target == nil {
+ ap.mu.Unlock()
+ closeOpenAIWSConns(evicted)
+ return nil, errOpenAIWSConnClosed
+ }
+ if int(target.waiters.Load()) >= p.queueLimitPerConn() {
+ ap.mu.Unlock()
+ closeOpenAIWSConns(evicted)
+ return nil, errOpenAIWSConnQueueFull
+ }
+ target.waiters.Add(1)
+ ap.mu.Unlock()
+ closeOpenAIWSConns(evicted)
+ defer target.waiters.Add(-1)
+ waitStart := time.Now()
+ p.metrics.acquireQueueWaitTotal.Add(1)
+
+ if err := target.acquire(ctx); err != nil {
+ if errors.Is(err, errOpenAIWSConnClosed) && retry < 1 {
+ return p.acquire(ctx, req, retry+1)
+ }
+ return nil, err
+ }
+ if p.shouldHealthCheckConn(target) {
+ if err := target.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
+ target.release()
+ target.close()
+ p.evictConn(accountID, target.id)
+ if retry < 1 {
+ return p.acquire(ctx, req, retry+1)
+ }
+ return nil, err
+ }
+ }
+
+ queueWait := time.Since(waitStart)
+ p.metrics.acquireQueueWaitMs.Add(queueWait.Milliseconds())
+ lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: target, queueWait: queueWait, connPick: connPick, reused: true}
+ p.metrics.acquireReuseTotal.Add(1)
+ p.ensureTargetIdleAsync(accountID)
+ return lease, nil
+}
+
+func (p *openAIWSConnPool) recordConnPickDuration(duration time.Duration) {
+ if p == nil {
+ return
+ }
+ if duration < 0 {
+ duration = 0
+ }
+ p.metrics.connPickTotal.Add(1)
+ p.metrics.connPickMs.Add(duration.Milliseconds())
+}
+
+func (p *openAIWSConnPool) pickOldestIdleConnLocked(ap *openAIWSAccountPool) *openAIWSConn {
+ if ap == nil || len(ap.conns) == 0 {
+ return nil
+ }
+ var oldest *openAIWSConn
+ for _, conn := range ap.conns {
+ if conn == nil || conn.isLeased() || conn.waiters.Load() > 0 || p.isConnPinnedLocked(ap, conn.id) {
+ continue
+ }
+ if oldest == nil || conn.lastUsedAt().Before(oldest.lastUsedAt()) {
+ oldest = conn
+ }
+ }
+ return oldest
+}
+
+func (p *openAIWSConnPool) getOrCreateAccountPool(accountID int64) *openAIWSAccountPool {
+ if p == nil || accountID <= 0 {
+ return nil
+ }
+ if existing, ok := p.accounts.Load(accountID); ok {
+ if ap, typed := existing.(*openAIWSAccountPool); typed && ap != nil {
+ return ap
+ }
+ }
+ ap := &openAIWSAccountPool{
+ conns: make(map[string]*openAIWSConn),
+ pinnedConns: make(map[string]int),
+ }
+ actual, _ := p.accounts.LoadOrStore(accountID, ap)
+ if typed, ok := actual.(*openAIWSAccountPool); ok && typed != nil {
+ return typed
+ }
+ return ap
+}
+
+// ensureAccountPoolLocked 兼容旧调用。
+func (p *openAIWSConnPool) ensureAccountPoolLocked(accountID int64) *openAIWSAccountPool {
+ return p.getOrCreateAccountPool(accountID)
+}
+
+func (p *openAIWSConnPool) getAccountPool(accountID int64) (*openAIWSAccountPool, bool) {
+ if p == nil || accountID <= 0 {
+ return nil, false
+ }
+ value, ok := p.accounts.Load(accountID)
+ if !ok || value == nil {
+ return nil, false
+ }
+ ap, typed := value.(*openAIWSAccountPool)
+ return ap, typed && ap != nil
+}
+
+func (p *openAIWSConnPool) isConnPinnedLocked(ap *openAIWSAccountPool, connID string) bool {
+ if ap == nil || connID == "" || len(ap.pinnedConns) == 0 {
+ return false
+ }
+ return ap.pinnedConns[connID] > 0
+}
+
+func (p *openAIWSConnPool) cleanupAccountLocked(ap *openAIWSAccountPool, now time.Time, maxConns int) []*openAIWSConn {
+ if ap == nil {
+ return nil
+ }
+ maxAge := p.maxConnAge()
+
+ evicted := make([]*openAIWSConn, 0)
+ for id, conn := range ap.conns {
+ if conn == nil {
+ delete(ap.conns, id)
+ if len(ap.pinnedConns) > 0 {
+ delete(ap.pinnedConns, id)
+ }
+ continue
+ }
+ select {
+ case <-conn.closedCh:
+ delete(ap.conns, id)
+ if len(ap.pinnedConns) > 0 {
+ delete(ap.pinnedConns, id)
+ }
+ evicted = append(evicted, conn)
+ continue
+ default:
+ }
+ if p.isConnPinnedLocked(ap, id) {
+ continue
+ }
+ if maxAge > 0 && !conn.isLeased() && conn.age(now) > maxAge {
+ delete(ap.conns, id)
+ if len(ap.pinnedConns) > 0 {
+ delete(ap.pinnedConns, id)
+ }
+ evicted = append(evicted, conn)
+ }
+ }
+
+ if maxConns <= 0 {
+ maxConns = p.maxConnsHardCap()
+ }
+ maxIdle := p.maxIdlePerAccount()
+ if maxIdle < 0 || maxIdle > maxConns {
+ maxIdle = maxConns
+ }
+ if maxIdle >= 0 && len(ap.conns) > maxIdle {
+ idleConns := make([]*openAIWSConn, 0, len(ap.conns))
+ for id, conn := range ap.conns {
+ if conn == nil {
+ delete(ap.conns, id)
+ if len(ap.pinnedConns) > 0 {
+ delete(ap.pinnedConns, id)
+ }
+ continue
+ }
+ // 有等待者的连接不能在清理阶段被淘汰,否则等待中的 acquire 会收到 closed 错误。
+ if conn.isLeased() || conn.waiters.Load() > 0 || p.isConnPinnedLocked(ap, conn.id) {
+ continue
+ }
+ idleConns = append(idleConns, conn)
+ }
+ sort.SliceStable(idleConns, func(i, j int) bool {
+ return idleConns[i].lastUsedAt().Before(idleConns[j].lastUsedAt())
+ })
+ redundant := len(ap.conns) - maxIdle
+ if redundant > len(idleConns) {
+ redundant = len(idleConns)
+ }
+ for i := 0; i < redundant; i++ {
+ conn := idleConns[i]
+ delete(ap.conns, conn.id)
+ if len(ap.pinnedConns) > 0 {
+ delete(ap.pinnedConns, conn.id)
+ }
+ evicted = append(evicted, conn)
+ }
+ if redundant > 0 {
+ p.metrics.scaleDownTotal.Add(int64(redundant))
+ }
+ }
+
+ return evicted
+}
+
+func (p *openAIWSConnPool) pickLeastBusyConnLocked(ap *openAIWSAccountPool, preferredConnID string) *openAIWSConn {
+ if ap == nil || len(ap.conns) == 0 {
+ return nil
+ }
+ preferredConnID = stringsTrim(preferredConnID)
+ if preferredConnID != "" {
+ if conn, ok := ap.conns[preferredConnID]; ok {
+ return conn
+ }
+ }
+ var best *openAIWSConn
+ var bestWaiters int32
+ var bestLastUsed time.Time
+ for _, conn := range ap.conns {
+ if conn == nil {
+ continue
+ }
+ waiters := conn.waiters.Load()
+ lastUsed := conn.lastUsedAt()
+ if best == nil ||
+ waiters < bestWaiters ||
+ (waiters == bestWaiters && lastUsed.Before(bestLastUsed)) {
+ best = conn
+ bestWaiters = waiters
+ bestLastUsed = lastUsed
+ }
+ }
+ return best
+}
+
+func accountPoolLoadLocked(ap *openAIWSAccountPool) (inflight int, waiters int) {
+ if ap == nil {
+ return 0, 0
+ }
+ for _, conn := range ap.conns {
+ if conn == nil {
+ continue
+ }
+ if conn.isLeased() {
+ inflight++
+ }
+ waiters += int(conn.waiters.Load())
+ }
+ return inflight, waiters
+}
+
+// AccountPoolLoad 返回指定账号连接池的并发与排队快照。
+func (p *openAIWSConnPool) AccountPoolLoad(accountID int64) (inflight int, waiters int, conns int) {
+ if p == nil || accountID <= 0 {
+ return 0, 0, 0
+ }
+ ap, ok := p.getAccountPool(accountID)
+ if !ok || ap == nil {
+ return 0, 0, 0
+ }
+ ap.mu.Lock()
+ defer ap.mu.Unlock()
+ inflight, waiters = accountPoolLoadLocked(ap)
+ return inflight, waiters, len(ap.conns)
+}
+
+func (p *openAIWSConnPool) ensureTargetIdleAsync(accountID int64) {
+ if p == nil || accountID <= 0 {
+ return
+ }
+
+ var req openAIWSAcquireRequest
+ need := 0
+ ap, ok := p.getAccountPool(accountID)
+ if !ok || ap == nil {
+ return
+ }
+ ap.mu.Lock()
+ defer ap.mu.Unlock()
+ if ap.lastAcquire == nil {
+ return
+ }
+ if ap.prewarmActive {
+ return
+ }
+ now := time.Now()
+ if !ap.prewarmUntil.IsZero() && now.Before(ap.prewarmUntil) {
+ return
+ }
+ if p.shouldSuppressPrewarmLocked(ap, now) {
+ return
+ }
+ effectiveMaxConns := p.maxConnsHardCap()
+ if ap.lastAcquire != nil && ap.lastAcquire.Account != nil {
+ effectiveMaxConns = p.effectiveMaxConnsByAccount(ap.lastAcquire.Account)
+ }
+ target := p.targetConnCountLocked(ap, effectiveMaxConns)
+ current := len(ap.conns) + ap.creating
+ if current >= target {
+ return
+ }
+ need = target - current
+ if need <= 0 {
+ return
+ }
+ req = cloneOpenAIWSAcquireRequest(*ap.lastAcquire)
+ ap.prewarmActive = true
+ if cooldown := p.prewarmCooldown(); cooldown > 0 {
+ ap.prewarmUntil = now.Add(cooldown)
+ }
+ ap.creating += need
+ p.metrics.scaleUpTotal.Add(int64(need))
+
+ go p.prewarmConns(accountID, req, need)
+}
+
+func (p *openAIWSConnPool) targetConnCountLocked(ap *openAIWSAccountPool, maxConns int) int {
+ if ap == nil {
+ return 0
+ }
+
+ if maxConns <= 0 {
+ return 0
+ }
+
+ minIdle := p.minIdlePerAccount()
+ if minIdle < 0 {
+ minIdle = 0
+ }
+ if minIdle > maxConns {
+ minIdle = maxConns
+ }
+
+ inflight, waiters := accountPoolLoadLocked(ap)
+ utilization := p.targetUtilization()
+ demand := inflight + waiters
+ if demand <= 0 {
+ return minIdle
+ }
+
+ target := 1
+ if demand > 1 {
+ target = int(math.Ceil(float64(demand) / utilization))
+ }
+ if waiters > 0 && target < len(ap.conns)+1 {
+ target = len(ap.conns) + 1
+ }
+ if target < minIdle {
+ target = minIdle
+ }
+ if target > maxConns {
+ target = maxConns
+ }
+ return target
+}
+
+func (p *openAIWSConnPool) prewarmConns(accountID int64, req openAIWSAcquireRequest, total int) {
+ defer func() {
+ if ap, ok := p.getAccountPool(accountID); ok && ap != nil {
+ ap.mu.Lock()
+ ap.prewarmActive = false
+ ap.mu.Unlock()
+ }
+ }()
+
+ for i := 0; i < total; i++ {
+ ctx, cancel := context.WithTimeout(context.Background(), p.dialTimeout()+openAIWSConnPrewarmExtraDelay)
+ conn, err := p.dialConn(ctx, req)
+ cancel()
+
+ ap, ok := p.getAccountPool(accountID)
+ if !ok || ap == nil {
+ if conn != nil {
+ conn.close()
+ }
+ return
+ }
+ ap.mu.Lock()
+ if ap.creating > 0 {
+ ap.creating--
+ }
+ if err != nil {
+ ap.prewarmFails++
+ ap.prewarmFailAt = time.Now()
+ ap.mu.Unlock()
+ continue
+ }
+ if len(ap.conns) >= p.effectiveMaxConnsByAccount(req.Account) {
+ ap.mu.Unlock()
+ conn.close()
+ continue
+ }
+ ap.conns[conn.id] = conn
+ ap.prewarmFails = 0
+ ap.prewarmFailAt = time.Time{}
+ ap.mu.Unlock()
+ }
+}
+
+func (p *openAIWSConnPool) evictConn(accountID int64, connID string) {
+ if p == nil || accountID <= 0 || stringsTrim(connID) == "" {
+ return
+ }
+ var conn *openAIWSConn
+ ap, ok := p.getAccountPool(accountID)
+ if ok && ap != nil {
+ ap.mu.Lock()
+ if c, exists := ap.conns[connID]; exists {
+ conn = c
+ delete(ap.conns, connID)
+ if len(ap.pinnedConns) > 0 {
+ delete(ap.pinnedConns, connID)
+ }
+ }
+ ap.mu.Unlock()
+ }
+ if conn != nil {
+ conn.close()
+ }
+}
+
+func (p *openAIWSConnPool) PinConn(accountID int64, connID string) bool {
+ if p == nil || accountID <= 0 {
+ return false
+ }
+ connID = stringsTrim(connID)
+ if connID == "" {
+ return false
+ }
+ ap, ok := p.getAccountPool(accountID)
+ if !ok || ap == nil {
+ return false
+ }
+ ap.mu.Lock()
+ defer ap.mu.Unlock()
+ if _, exists := ap.conns[connID]; !exists {
+ return false
+ }
+ if ap.pinnedConns == nil {
+ ap.pinnedConns = make(map[string]int)
+ }
+ ap.pinnedConns[connID]++
+ return true
+}
+
+func (p *openAIWSConnPool) UnpinConn(accountID int64, connID string) {
+ if p == nil || accountID <= 0 {
+ return
+ }
+ connID = stringsTrim(connID)
+ if connID == "" {
+ return
+ }
+ ap, ok := p.getAccountPool(accountID)
+ if !ok || ap == nil {
+ return
+ }
+ ap.mu.Lock()
+ defer ap.mu.Unlock()
+ if len(ap.pinnedConns) == 0 {
+ return
+ }
+ count := ap.pinnedConns[connID]
+ if count <= 1 {
+ delete(ap.pinnedConns, connID)
+ return
+ }
+ ap.pinnedConns[connID] = count - 1
+}
+
+func (p *openAIWSConnPool) dialConn(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConn, error) {
+ if p == nil || p.clientDialer == nil {
+ return nil, errors.New("openai ws client dialer is nil")
+ }
+ conn, status, handshakeHeaders, err := p.clientDialer.Dial(ctx, req.WSURL, req.Headers, req.ProxyURL)
+ if err != nil {
+ return nil, &openAIWSDialError{
+ StatusCode: status,
+ ResponseHeaders: cloneHeader(handshakeHeaders),
+ Err: err,
+ }
+ }
+ if conn == nil {
+ return nil, &openAIWSDialError{
+ StatusCode: status,
+ ResponseHeaders: cloneHeader(handshakeHeaders),
+ Err: errors.New("openai ws dialer returned nil connection"),
+ }
+ }
+ id := p.nextConnID(req.Account.ID)
+ return newOpenAIWSConn(id, req.Account.ID, conn, handshakeHeaders), nil
+}
+
+func (p *openAIWSConnPool) nextConnID(accountID int64) string {
+ seq := p.seq.Add(1)
+ buf := make([]byte, 0, 32)
+ buf = append(buf, "oa_ws_"...)
+ buf = strconv.AppendInt(buf, accountID, 10)
+ buf = append(buf, '_')
+ buf = strconv.AppendUint(buf, seq, 10)
+ return string(buf)
+}
+
+func (p *openAIWSConnPool) shouldHealthCheckConn(conn *openAIWSConn) bool {
+ if conn == nil {
+ return false
+ }
+ return conn.idleDuration(time.Now()) >= openAIWSConnHealthCheckIdle
+}
+
+func (p *openAIWSConnPool) maxConnsHardCap() int {
+ if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount > 0 {
+ return p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount
+ }
+ return 8
+}
+
+func (p *openAIWSConnPool) dynamicMaxConnsEnabled() bool {
+ if p != nil && p.cfg != nil {
+ return p.cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled
+ }
+ return false
+}
+
+func (p *openAIWSConnPool) modeRouterV2Enabled() bool {
+ if p != nil && p.cfg != nil {
+ return p.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
+ }
+ return false
+}
+
+func (p *openAIWSConnPool) maxConnsFactorByAccount(account *Account) float64 {
+ if p == nil || p.cfg == nil || account == nil {
+ return 1.0
+ }
+ switch account.Type {
+ case AccountTypeOAuth:
+ if p.cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor > 0 {
+ return p.cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor
+ }
+ case AccountTypeAPIKey:
+ if p.cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor > 0 {
+ return p.cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor
+ }
+ }
+ return 1.0
+}
+
+func (p *openAIWSConnPool) effectiveMaxConnsByAccount(account *Account) int {
+ hardCap := p.maxConnsHardCap()
+ if hardCap <= 0 {
+ return 0
+ }
+ if p.modeRouterV2Enabled() {
+ if account == nil {
+ return hardCap
+ }
+ if account.Concurrency <= 0 {
+ return 0
+ }
+ return account.Concurrency
+ }
+ if account == nil || !p.dynamicMaxConnsEnabled() {
+ return hardCap
+ }
+ if account.Concurrency <= 0 {
+ // 0/-1 等“无限制”并发场景下,仍由全局硬上限兜底。
+ return hardCap
+ }
+ factor := p.maxConnsFactorByAccount(account)
+ if factor <= 0 {
+ factor = 1.0
+ }
+ effective := int(math.Ceil(float64(account.Concurrency) * factor))
+ if effective < 1 {
+ effective = 1
+ }
+ if effective > hardCap {
+ effective = hardCap
+ }
+ return effective
+}
+
+func (p *openAIWSConnPool) minIdlePerAccount() int {
+ if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MinIdlePerAccount >= 0 {
+ return p.cfg.Gateway.OpenAIWS.MinIdlePerAccount
+ }
+ return 0
+}
+
+func (p *openAIWSConnPool) maxIdlePerAccount() int {
+ if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MaxIdlePerAccount >= 0 {
+ return p.cfg.Gateway.OpenAIWS.MaxIdlePerAccount
+ }
+ return 4
+}
+
+func (p *openAIWSConnPool) maxConnAge() time.Duration {
+ return openAIWSConnMaxAge
+}
+
+func (p *openAIWSConnPool) queueLimitPerConn() int {
+ if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.QueueLimitPerConn > 0 {
+ return p.cfg.Gateway.OpenAIWS.QueueLimitPerConn
+ }
+ return 256
+}
+
+func (p *openAIWSConnPool) targetUtilization() float64 {
+ if p != nil && p.cfg != nil {
+ ratio := p.cfg.Gateway.OpenAIWS.PoolTargetUtilization
+ if ratio > 0 && ratio <= 1 {
+ return ratio
+ }
+ }
+ return 0.7
+}
+
+func (p *openAIWSConnPool) prewarmCooldown() time.Duration {
+ if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.PrewarmCooldownMS > 0 {
+ return time.Duration(p.cfg.Gateway.OpenAIWS.PrewarmCooldownMS) * time.Millisecond
+ }
+ return 0
+}
+
+func (p *openAIWSConnPool) shouldSuppressPrewarmLocked(ap *openAIWSAccountPool, now time.Time) bool {
+ if ap == nil {
+ return true
+ }
+ if ap.prewarmFails <= 0 {
+ return false
+ }
+ if ap.prewarmFailAt.IsZero() {
+ ap.prewarmFails = 0
+ return false
+ }
+ if now.Sub(ap.prewarmFailAt) > openAIWSPrewarmFailureWindow {
+ ap.prewarmFails = 0
+ ap.prewarmFailAt = time.Time{}
+ return false
+ }
+ return ap.prewarmFails >= openAIWSPrewarmFailureSuppress
+}
+
+func (p *openAIWSConnPool) dialTimeout() time.Duration {
+ if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.DialTimeoutSeconds > 0 {
+ return time.Duration(p.cfg.Gateway.OpenAIWS.DialTimeoutSeconds) * time.Second
+ }
+ return 10 * time.Second
+}
+
+func cloneOpenAIWSAcquireRequest(req openAIWSAcquireRequest) openAIWSAcquireRequest {
+ copied := req
+ copied.Headers = cloneHeader(req.Headers)
+ copied.WSURL = stringsTrim(req.WSURL)
+ copied.ProxyURL = stringsTrim(req.ProxyURL)
+ copied.PreferredConnID = stringsTrim(req.PreferredConnID)
+ return copied
+}
+
+func cloneOpenAIWSAcquireRequestPtr(req *openAIWSAcquireRequest) *openAIWSAcquireRequest {
+ if req == nil {
+ return nil
+ }
+ copied := cloneOpenAIWSAcquireRequest(*req)
+ return &copied
+}
+
+func cloneHeader(src http.Header) http.Header {
+ if src == nil {
+ return nil
+ }
+ dst := make(http.Header, len(src))
+ for k, vals := range src {
+ if len(vals) == 0 {
+ dst[k] = nil
+ continue
+ }
+ copied := make([]string, len(vals))
+ copy(copied, vals)
+ dst[k] = copied
+ }
+ return dst
+}
+
+func closeOpenAIWSConns(conns []*openAIWSConn) {
+ if len(conns) == 0 {
+ return
+ }
+ for _, conn := range conns {
+ if conn == nil {
+ continue
+ }
+ conn.close()
+ }
+}
+
+func stringsTrim(value string) string {
+ return strings.TrimSpace(value)
+}
diff --git a/backend/internal/service/openai_ws_pool_benchmark_test.go b/backend/internal/service/openai_ws_pool_benchmark_test.go
new file mode 100644
index 00000000..bff74b62
--- /dev/null
+++ b/backend/internal/service/openai_ws_pool_benchmark_test.go
@@ -0,0 +1,58 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+)
+
+func BenchmarkOpenAIWSPoolAcquire(b *testing.B) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 256
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
+
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(&openAIWSCountingDialer{})
+
+ account := &Account{ID: 1001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+ req := openAIWSAcquireRequest{
+ Account: account,
+ WSURL: "wss://example.com/v1/responses",
+ }
+ ctx := context.Background()
+
+ lease, err := pool.Acquire(ctx, req)
+ if err != nil {
+ b.Fatalf("warm acquire failed: %v", err)
+ }
+ lease.Release()
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ var (
+ got *openAIWSConnLease
+ acquireErr error
+ )
+ for retry := 0; retry < 3; retry++ {
+ got, acquireErr = pool.Acquire(ctx, req)
+ if acquireErr == nil {
+ break
+ }
+ if !errors.Is(acquireErr, errOpenAIWSConnClosed) {
+ break
+ }
+ }
+ if acquireErr != nil {
+ b.Fatalf("acquire failed: %v", acquireErr)
+ }
+ got.Release()
+ }
+ })
+}
diff --git a/backend/internal/service/openai_ws_pool_test.go b/backend/internal/service/openai_ws_pool_test.go
new file mode 100644
index 00000000..b2683ee0
--- /dev/null
+++ b/backend/internal/service/openai_ws_pool_test.go
@@ -0,0 +1,1709 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+func TestOpenAIWSConnPool_CleanupStaleAndTrimIdle(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ pool := newOpenAIWSConnPool(cfg)
+
+ accountID := int64(10)
+ ap := pool.getOrCreateAccountPool(accountID)
+
+ stale := newOpenAIWSConn("stale", accountID, nil, nil)
+ stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
+ stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
+
+ idleOld := newOpenAIWSConn("idle_old", accountID, nil, nil)
+ idleOld.lastUsedNano.Store(time.Now().Add(-10 * time.Minute).UnixNano())
+
+ idleNew := newOpenAIWSConn("idle_new", accountID, nil, nil)
+ idleNew.lastUsedNano.Store(time.Now().Add(-1 * time.Minute).UnixNano())
+
+ ap.conns[stale.id] = stale
+ ap.conns[idleOld.id] = idleOld
+ ap.conns[idleNew.id] = idleNew
+
+ evicted := pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap())
+ closeOpenAIWSConns(evicted)
+
+ require.Nil(t, ap.conns["stale"], "stale connection should be rotated")
+ require.Nil(t, ap.conns["idle_old"], "old idle should be trimmed by max_idle")
+ require.NotNil(t, ap.conns["idle_new"], "newer idle should be kept")
+}
+
+func TestOpenAIWSConnPool_NextConnIDFormat(t *testing.T) {
+ pool := newOpenAIWSConnPool(&config.Config{})
+ id1 := pool.nextConnID(42)
+ id2 := pool.nextConnID(42)
+
+ require.True(t, strings.HasPrefix(id1, "oa_ws_42_"))
+ require.True(t, strings.HasPrefix(id2, "oa_ws_42_"))
+ require.NotEqual(t, id1, id2)
+ require.Equal(t, "oa_ws_42_1", id1)
+ require.Equal(t, "oa_ws_42_2", id2)
+}
+
+func TestOpenAIWSConnPool_AcquireCleanupInterval(t *testing.T) {
+ require.Equal(t, 3*time.Second, openAIWSAcquireCleanupInterval)
+ require.Less(t, openAIWSAcquireCleanupInterval, openAIWSBackgroundSweepTicker)
+}
+
+func TestOpenAIWSConnLease_WriteJSONAndGuards(t *testing.T) {
+ conn := newOpenAIWSConn("lease_write", 1, &openAIWSFakeConn{}, nil)
+ lease := &openAIWSConnLease{conn: conn}
+ require.NoError(t, lease.WriteJSON(map[string]any{"type": "response.create"}, 0))
+
+ var nilLease *openAIWSConnLease
+ err := nilLease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+
+ err = (&openAIWSConnLease{}).WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+}
+
+func TestOpenAIWSConn_WriteJSONWithTimeout_NilParentContextUsesBackground(t *testing.T) {
+ probe := &openAIWSContextProbeConn{}
+ conn := newOpenAIWSConn("ctx_probe", 1, probe, nil)
+ require.NoError(t, conn.writeJSONWithTimeout(context.Background(), map[string]any{"type": "response.create"}, 0))
+ require.NotNil(t, probe.lastWriteCtx)
+}
+
+func TestOpenAIWSConnPool_TargetConnCountAdaptive(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 6
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.5
+
+ pool := newOpenAIWSConnPool(cfg)
+ ap := pool.getOrCreateAccountPool(88)
+
+ conn1 := newOpenAIWSConn("c1", 88, nil, nil)
+ conn2 := newOpenAIWSConn("c2", 88, nil, nil)
+ require.True(t, conn1.tryAcquire())
+ require.True(t, conn2.tryAcquire())
+ conn1.waiters.Store(1)
+ conn2.waiters.Store(1)
+
+ ap.conns[conn1.id] = conn1
+ ap.conns[conn2.id] = conn2
+
+ target := pool.targetConnCountLocked(ap, pool.maxConnsHardCap())
+ require.Equal(t, 6, target, "应按 inflight+waiters 与 target_utilization 自适应扩容到上限")
+
+ conn1.release()
+ conn2.release()
+ conn1.waiters.Store(0)
+ conn2.waiters.Store(0)
+ target = pool.targetConnCountLocked(ap, pool.maxConnsHardCap())
+ require.Equal(t, 1, target, "低负载时应缩回到最小空闲连接")
+}
+
+func TestOpenAIWSConnPool_TargetConnCountMinIdleZero(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
+
+ pool := newOpenAIWSConnPool(cfg)
+ ap := pool.getOrCreateAccountPool(66)
+
+ target := pool.targetConnCountLocked(ap, pool.maxConnsHardCap())
+ require.Equal(t, 0, target, "min_idle=0 且无负载时应允许缩容到 0")
+}
+
+func TestOpenAIWSConnPool_EnsureTargetIdleAsync(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 2
+ cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
+
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(&openAIWSFakeDialer{})
+
+ accountID := int64(77)
+ account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+ ap := pool.getOrCreateAccountPool(accountID)
+ ap.mu.Lock()
+ ap.lastAcquire = &openAIWSAcquireRequest{
+ Account: account,
+ WSURL: "wss://example.com/v1/responses",
+ }
+ ap.mu.Unlock()
+
+ pool.ensureTargetIdleAsync(accountID)
+
+ require.Eventually(t, func() bool {
+ ap, ok := pool.getAccountPool(accountID)
+ if !ok || ap == nil {
+ return false
+ }
+ ap.mu.Lock()
+ defer ap.mu.Unlock()
+ return len(ap.conns) >= 2
+ }, 2*time.Second, 20*time.Millisecond)
+
+ metrics := pool.SnapshotMetrics()
+ require.GreaterOrEqual(t, metrics.ScaleUpTotal, int64(2))
+}
+
+func TestOpenAIWSConnPool_EnsureTargetIdleAsyncCooldown(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 2
+ cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
+ cfg.Gateway.OpenAIWS.PrewarmCooldownMS = 500
+
+ pool := newOpenAIWSConnPool(cfg)
+ dialer := &openAIWSCountingDialer{}
+ pool.setClientDialerForTest(dialer)
+
+ accountID := int64(178)
+ account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+ ap := pool.getOrCreateAccountPool(accountID)
+ ap.mu.Lock()
+ ap.lastAcquire = &openAIWSAcquireRequest{
+ Account: account,
+ WSURL: "wss://example.com/v1/responses",
+ }
+ ap.mu.Unlock()
+
+ pool.ensureTargetIdleAsync(accountID)
+ require.Eventually(t, func() bool {
+ ap, ok := pool.getAccountPool(accountID)
+ if !ok || ap == nil {
+ return false
+ }
+ ap.mu.Lock()
+ defer ap.mu.Unlock()
+ return len(ap.conns) >= 2 && !ap.prewarmActive
+ }, 2*time.Second, 20*time.Millisecond)
+ firstDialCount := dialer.DialCount()
+ require.GreaterOrEqual(t, firstDialCount, 2)
+
+ // 人工制造缺口触发新一轮预热需求。
+ ap, ok := pool.getAccountPool(accountID)
+ require.True(t, ok)
+ require.NotNil(t, ap)
+ ap.mu.Lock()
+ for id := range ap.conns {
+ delete(ap.conns, id)
+ break
+ }
+ ap.mu.Unlock()
+
+ pool.ensureTargetIdleAsync(accountID)
+ time.Sleep(120 * time.Millisecond)
+ require.Equal(t, firstDialCount, dialer.DialCount(), "cooldown 窗口内不应再次触发预热")
+
+ time.Sleep(450 * time.Millisecond)
+ pool.ensureTargetIdleAsync(accountID)
+ require.Eventually(t, func() bool {
+ return dialer.DialCount() > firstDialCount
+ }, 2*time.Second, 20*time.Millisecond)
+}
+
+func TestOpenAIWSConnPool_EnsureTargetIdleAsyncFailureSuppress(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
+ cfg.Gateway.OpenAIWS.PrewarmCooldownMS = 0
+
+ pool := newOpenAIWSConnPool(cfg)
+ dialer := &openAIWSAlwaysFailDialer{}
+ pool.setClientDialerForTest(dialer)
+
+ accountID := int64(279)
+ account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+ ap := pool.getOrCreateAccountPool(accountID)
+ ap.mu.Lock()
+ ap.lastAcquire = &openAIWSAcquireRequest{
+ Account: account,
+ WSURL: "wss://example.com/v1/responses",
+ }
+ ap.mu.Unlock()
+
+ pool.ensureTargetIdleAsync(accountID)
+ require.Eventually(t, func() bool {
+ ap, ok := pool.getAccountPool(accountID)
+ if !ok || ap == nil {
+ return false
+ }
+ ap.mu.Lock()
+ defer ap.mu.Unlock()
+ return !ap.prewarmActive
+ }, 2*time.Second, 20*time.Millisecond)
+
+ pool.ensureTargetIdleAsync(accountID)
+ require.Eventually(t, func() bool {
+ ap, ok := pool.getAccountPool(accountID)
+ if !ok || ap == nil {
+ return false
+ }
+ ap.mu.Lock()
+ defer ap.mu.Unlock()
+ return !ap.prewarmActive
+ }, 2*time.Second, 20*time.Millisecond)
+ require.Equal(t, 2, dialer.DialCount())
+
+ // 连续失败达到阈值后,新的预热触发应被抑制,不再继续拨号。
+ pool.ensureTargetIdleAsync(accountID)
+ time.Sleep(120 * time.Millisecond)
+ require.Equal(t, 2, dialer.DialCount())
+}
+
+func TestOpenAIWSConnPool_AcquireQueueWaitMetrics(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 4
+
+ pool := newOpenAIWSConnPool(cfg)
+ accountID := int64(99)
+ account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+ conn := newOpenAIWSConn("busy", accountID, &openAIWSFakeConn{}, nil)
+ require.True(t, conn.tryAcquire()) // 占用连接,触发后续排队
+
+ ap := pool.ensureAccountPoolLocked(accountID)
+ ap.mu.Lock()
+ ap.conns[conn.id] = conn
+ ap.lastAcquire = &openAIWSAcquireRequest{
+ Account: account,
+ WSURL: "wss://example.com/v1/responses",
+ }
+ ap.mu.Unlock()
+
+ go func() {
+ time.Sleep(60 * time.Millisecond)
+ conn.release()
+ }()
+
+ lease, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
+ Account: account,
+ WSURL: "wss://example.com/v1/responses",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease)
+ require.True(t, lease.Reused())
+ require.GreaterOrEqual(t, lease.QueueWaitDuration(), 50*time.Millisecond)
+ lease.Release()
+
+ metrics := pool.SnapshotMetrics()
+ require.GreaterOrEqual(t, metrics.AcquireQueueWaitTotal, int64(1))
+ require.Greater(t, metrics.AcquireQueueWaitMsTotal, int64(0))
+ require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1))
+}
+
+func TestOpenAIWSConnPool_ForceNewConnSkipsReuse(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
+
+ pool := newOpenAIWSConnPool(cfg)
+ dialer := &openAIWSCountingDialer{}
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 123, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ lease1, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
+ Account: account,
+ WSURL: "wss://example.com/v1/responses",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease1)
+ lease1.Release()
+
+ lease2, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
+ Account: account,
+ WSURL: "wss://example.com/v1/responses",
+ ForceNewConn: true,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease2)
+ lease2.Release()
+
+ require.Equal(t, 2, dialer.DialCount(), "ForceNewConn=true 时应跳过空闲连接复用并新建连接")
+}
+
+func TestOpenAIWSConnPool_AcquireForcePreferredConnUnavailable(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
+
+ pool := newOpenAIWSConnPool(cfg)
+ account := &Account{ID: 124, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+ ap := pool.getOrCreateAccountPool(account.ID)
+ otherConn := newOpenAIWSConn("other_conn", account.ID, &openAIWSFakeConn{}, nil)
+ ap.mu.Lock()
+ ap.conns[otherConn.id] = otherConn
+ ap.mu.Unlock()
+
+ _, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
+ Account: account,
+ WSURL: "wss://example.com/v1/responses",
+ ForcePreferredConn: true,
+ })
+ require.ErrorIs(t, err, errOpenAIWSPreferredConnUnavailable)
+
+ _, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{
+ Account: account,
+ WSURL: "wss://example.com/v1/responses",
+ PreferredConnID: "missing_conn",
+ ForcePreferredConn: true,
+ })
+ require.ErrorIs(t, err, errOpenAIWSPreferredConnUnavailable)
+}
+
+func TestOpenAIWSConnPool_AcquireForcePreferredConnQueuesOnPreferredOnly(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 4
+
+ pool := newOpenAIWSConnPool(cfg)
+ account := &Account{ID: 125, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+ ap := pool.getOrCreateAccountPool(account.ID)
+ preferredConn := newOpenAIWSConn("preferred_conn", account.ID, &openAIWSFakeConn{}, nil)
+ otherConn := newOpenAIWSConn("other_conn_idle", account.ID, &openAIWSFakeConn{}, nil)
+ require.True(t, preferredConn.tryAcquire(), "先占用 preferred 连接,触发排队获取")
+ ap.mu.Lock()
+ ap.conns[preferredConn.id] = preferredConn
+ ap.conns[otherConn.id] = otherConn
+ ap.lastCleanupAt = time.Now()
+ ap.mu.Unlock()
+
+ go func() {
+ time.Sleep(60 * time.Millisecond)
+ preferredConn.release()
+ }()
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
+ lease, err := pool.Acquire(ctx, openAIWSAcquireRequest{
+ Account: account,
+ WSURL: "wss://example.com/v1/responses",
+ PreferredConnID: preferredConn.id,
+ ForcePreferredConn: true,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease)
+ require.Equal(t, preferredConn.id, lease.ConnID(), "严格模式应只等待并复用 preferred 连接,不可漂移")
+ require.GreaterOrEqual(t, lease.QueueWaitDuration(), 40*time.Millisecond)
+ lease.Release()
+ require.True(t, otherConn.tryAcquire(), "other 连接不应被严格模式抢占")
+ otherConn.release()
+}
+
+func TestOpenAIWSConnPool_AcquireForcePreferredConnDirectAndQueueFull(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 1
+
+ pool := newOpenAIWSConnPool(cfg)
+ account := &Account{ID: 127, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+ ap := pool.getOrCreateAccountPool(account.ID)
+ preferredConn := newOpenAIWSConn("preferred_conn_direct", account.ID, &openAIWSFakeConn{}, nil)
+ otherConn := newOpenAIWSConn("other_conn_direct", account.ID, &openAIWSFakeConn{}, nil)
+ ap.mu.Lock()
+ ap.conns[preferredConn.id] = preferredConn
+ ap.conns[otherConn.id] = otherConn
+ ap.lastCleanupAt = time.Now()
+ ap.mu.Unlock()
+
+ lease, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
+ Account: account,
+ WSURL: "wss://example.com/v1/responses",
+ PreferredConnID: preferredConn.id,
+ ForcePreferredConn: true,
+ })
+ require.NoError(t, err)
+ require.Equal(t, preferredConn.id, lease.ConnID(), "preferred 空闲时应直接命中")
+ lease.Release()
+
+ require.True(t, preferredConn.tryAcquire())
+ preferredConn.waiters.Store(1)
+ _, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{
+ Account: account,
+ WSURL: "wss://example.com/v1/responses",
+ PreferredConnID: preferredConn.id,
+ ForcePreferredConn: true,
+ })
+ require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "严格模式下队列满应直接失败,不得漂移")
+ preferredConn.waiters.Store(0)
+ preferredConn.release()
+}
+
+func TestOpenAIWSConnPool_CleanupSkipsPinnedConn(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 0
+
+ pool := newOpenAIWSConnPool(cfg)
+ accountID := int64(126)
+ ap := pool.getOrCreateAccountPool(accountID)
+ pinnedConn := newOpenAIWSConn("pinned_conn", accountID, &openAIWSFakeConn{}, nil)
+ idleConn := newOpenAIWSConn("idle_conn", accountID, &openAIWSFakeConn{}, nil)
+ ap.mu.Lock()
+ ap.conns[pinnedConn.id] = pinnedConn
+ ap.conns[idleConn.id] = idleConn
+ ap.mu.Unlock()
+
+ require.True(t, pool.PinConn(accountID, pinnedConn.id))
+ evicted := pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap())
+ closeOpenAIWSConns(evicted)
+
+ ap.mu.Lock()
+ _, pinnedExists := ap.conns[pinnedConn.id]
+ _, idleExists := ap.conns[idleConn.id]
+ ap.mu.Unlock()
+ require.True(t, pinnedExists, "被 active ingress 绑定的连接不应被 cleanup 回收")
+ require.False(t, idleExists, "非绑定的空闲连接应被回收")
+
+ pool.UnpinConn(accountID, pinnedConn.id)
+ evicted = pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap())
+ closeOpenAIWSConns(evicted)
+ ap.mu.Lock()
+ _, pinnedExists = ap.conns[pinnedConn.id]
+ ap.mu.Unlock()
+ require.False(t, pinnedExists, "解绑后连接应可被正常回收")
+}
+
+func TestOpenAIWSConnPool_PinUnpinConnBranches(t *testing.T) {
+ var nilPool *openAIWSConnPool
+ require.False(t, nilPool.PinConn(1, "x"))
+ nilPool.UnpinConn(1, "x")
+
+ cfg := &config.Config{}
+ pool := newOpenAIWSConnPool(cfg)
+ accountID := int64(128)
+ ap := &openAIWSAccountPool{
+ conns: map[string]*openAIWSConn{},
+ }
+ pool.accounts.Store(accountID, ap)
+
+ require.False(t, pool.PinConn(0, "x"))
+ require.False(t, pool.PinConn(999, "x"))
+ require.False(t, pool.PinConn(accountID, ""))
+ require.False(t, pool.PinConn(accountID, "missing"))
+
+ conn := newOpenAIWSConn("pin_refcount", accountID, &openAIWSFakeConn{}, nil)
+ ap.mu.Lock()
+ ap.conns[conn.id] = conn
+ ap.mu.Unlock()
+ require.True(t, pool.PinConn(accountID, conn.id))
+ require.True(t, pool.PinConn(accountID, conn.id))
+
+ ap.mu.Lock()
+ require.Equal(t, 2, ap.pinnedConns[conn.id])
+ ap.mu.Unlock()
+
+ pool.UnpinConn(accountID, conn.id)
+ ap.mu.Lock()
+ require.Equal(t, 1, ap.pinnedConns[conn.id])
+ ap.mu.Unlock()
+
+ pool.UnpinConn(accountID, conn.id)
+ ap.mu.Lock()
+ _, exists := ap.pinnedConns[conn.id]
+ ap.mu.Unlock()
+ require.False(t, exists)
+
+ pool.UnpinConn(accountID, conn.id)
+ pool.UnpinConn(accountID, "")
+ pool.UnpinConn(0, conn.id)
+ pool.UnpinConn(999, conn.id)
+}
+
+func TestOpenAIWSConnPool_EffectiveMaxConnsByAccount(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
+ cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true
+ cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 1.0
+ cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0.6
+
+ pool := newOpenAIWSConnPool(cfg)
+
+ oauthHigh := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 10}
+ require.Equal(t, 8, pool.effectiveMaxConnsByAccount(oauthHigh), "应受全局硬上限约束")
+
+ oauthLow := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 3}
+ require.Equal(t, 3, pool.effectiveMaxConnsByAccount(oauthLow))
+
+ apiKeyHigh := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 10}
+ require.Equal(t, 6, pool.effectiveMaxConnsByAccount(apiKeyHigh), "API Key 应按系数缩放")
+
+ apiKeyLow := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1}
+ require.Equal(t, 1, pool.effectiveMaxConnsByAccount(apiKeyLow), "最小值应保持为 1")
+
+ unlimited := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 0}
+ require.Equal(t, 8, pool.effectiveMaxConnsByAccount(unlimited), "无限并发应回退到全局硬上限")
+
+ require.Equal(t, 8, pool.effectiveMaxConnsByAccount(nil), "缺少账号上下文应回退到全局硬上限")
+}
+
+func TestOpenAIWSConnPool_EffectiveMaxConnsDisabledFallbackHardCap(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
+ cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = false
+ cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 1.0
+ cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 1.0
+
+ pool := newOpenAIWSConnPool(cfg)
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 2}
+ require.Equal(t, 8, pool.effectiveMaxConnsByAccount(account), "关闭动态模式后应保持旧行为")
+}
+
+func TestOpenAIWSConnPool_EffectiveMaxConnsByAccount_ModeRouterV2UsesAccountConcurrency(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
+ cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true
+ cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 0.3
+ cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0.6
+
+ pool := newOpenAIWSConnPool(cfg)
+
+ high := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 20}
+ require.Equal(t, 20, pool.effectiveMaxConnsByAccount(high), "v2 路径应直接使用账号并发数作为池上限")
+
+ nonPositive := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 0}
+ require.Equal(t, 0, pool.effectiveMaxConnsByAccount(nonPositive), "并发数<=0 时应不可调度")
+}
+
+func TestOpenAIWSConnPool_AcquireRejectsWhenEffectiveMaxConnsIsZero(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
+ pool := newOpenAIWSConnPool(cfg)
+
+ account := &Account{ID: 901, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 0}
+ _, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
+ Account: account,
+ WSURL: "wss://example.com/v1/responses",
+ })
+ require.ErrorIs(t, err, errOpenAIWSConnQueueFull)
+}
+
+func TestOpenAIWSConnLease_ReadMessageWithContextTimeout_PerRead(t *testing.T) {
+ conn := newOpenAIWSConn("timeout", 1, &openAIWSBlockingConn{readDelay: 80 * time.Millisecond}, nil)
+ lease := &openAIWSConnLease{conn: conn}
+
+ _, err := lease.ReadMessageWithContextTimeout(context.Background(), 20*time.Millisecond)
+ require.Error(t, err)
+ require.ErrorIs(t, err, context.DeadlineExceeded)
+
+ payload, err := lease.ReadMessageWithContextTimeout(context.Background(), 150*time.Millisecond)
+ require.NoError(t, err)
+ require.Contains(t, string(payload), "response.completed")
+
+ parentCtx, cancel := context.WithCancel(context.Background())
+ cancel()
+ _, err = lease.ReadMessageWithContextTimeout(parentCtx, 150*time.Millisecond)
+ require.Error(t, err)
+ require.ErrorIs(t, err, context.Canceled)
+}
+
+func TestOpenAIWSConnLease_WriteJSONWithContextTimeout_RespectsParentContext(t *testing.T) {
+ conn := newOpenAIWSConn("write_timeout_ctx", 1, &openAIWSWriteBlockingConn{}, nil)
+ lease := &openAIWSConnLease{conn: conn}
+
+ parentCtx, cancel := context.WithCancel(context.Background())
+ go func() {
+ time.Sleep(20 * time.Millisecond)
+ cancel()
+ }()
+
+ start := time.Now()
+ err := lease.WriteJSONWithContextTimeout(parentCtx, map[string]any{"type": "response.create"}, 2*time.Minute)
+ elapsed := time.Since(start)
+
+ require.Error(t, err)
+ require.ErrorIs(t, err, context.Canceled)
+ require.Less(t, elapsed, 200*time.Millisecond)
+}
+
+func TestOpenAIWSConnLease_PingWithTimeout(t *testing.T) {
+ conn := newOpenAIWSConn("ping_ok", 1, &openAIWSFakeConn{}, nil)
+ lease := &openAIWSConnLease{conn: conn}
+ require.NoError(t, lease.PingWithTimeout(50*time.Millisecond))
+
+ var nilLease *openAIWSConnLease
+ err := nilLease.PingWithTimeout(50 * time.Millisecond)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+}
+
+func TestOpenAIWSConn_ReadAndWriteCanProceedConcurrently(t *testing.T) {
+ conn := newOpenAIWSConn("full_duplex", 1, &openAIWSBlockingConn{readDelay: 120 * time.Millisecond}, nil)
+
+ readDone := make(chan error, 1)
+ go func() {
+ _, err := conn.readMessageWithContextTimeout(context.Background(), 200*time.Millisecond)
+ readDone <- err
+ }()
+
+ // 让读取先占用 readMu。
+ time.Sleep(20 * time.Millisecond)
+
+ start := time.Now()
+ err := conn.pingWithTimeout(50 * time.Millisecond)
+ elapsed := time.Since(start)
+
+ require.NoError(t, err)
+ require.Less(t, elapsed, 80*time.Millisecond, "写路径不应被读锁长期阻塞")
+ require.NoError(t, <-readDone)
+}
+
+func TestOpenAIWSConnPool_BackgroundPingSweep_EvictsDeadIdleConn(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
+ pool := newOpenAIWSConnPool(cfg)
+
+ accountID := int64(301)
+ ap := pool.getOrCreateAccountPool(accountID)
+ conn := newOpenAIWSConn("dead_idle", accountID, &openAIWSPingFailConn{}, nil)
+ ap.mu.Lock()
+ ap.conns[conn.id] = conn
+ ap.mu.Unlock()
+
+ pool.runBackgroundPingSweep()
+
+ ap.mu.Lock()
+ _, exists := ap.conns[conn.id]
+ ap.mu.Unlock()
+ require.False(t, exists, "后台 ping 失败的空闲连接应被回收")
+}
+
+func TestOpenAIWSConnPool_BackgroundCleanupSweep_WithoutAcquire(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
+ pool := newOpenAIWSConnPool(cfg)
+
+ accountID := int64(302)
+ ap := pool.getOrCreateAccountPool(accountID)
+ stale := newOpenAIWSConn("stale_bg", accountID, &openAIWSFakeConn{}, nil)
+ stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
+ stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
+ ap.mu.Lock()
+ ap.conns[stale.id] = stale
+ ap.mu.Unlock()
+
+ pool.runBackgroundCleanupSweep(time.Now())
+
+ ap.mu.Lock()
+ _, exists := ap.conns[stale.id]
+ ap.mu.Unlock()
+ require.False(t, exists, "后台清理应在无新 acquire 时也回收过期连接")
+}
+
+func TestOpenAIWSConnPool_BackgroundWorkerGuardBranches(t *testing.T) {
+ var nilPool *openAIWSConnPool
+ require.NotPanics(t, func() {
+ nilPool.startBackgroundWorkers()
+ nilPool.runBackgroundPingWorker()
+ nilPool.runBackgroundPingSweep()
+ _ = nilPool.snapshotIdleConnsForPing()
+ nilPool.runBackgroundCleanupWorker()
+ nilPool.runBackgroundCleanupSweep(time.Now())
+ })
+
+ poolNoStop := &openAIWSConnPool{}
+ require.NotPanics(t, func() {
+ poolNoStop.startBackgroundWorkers()
+ })
+
+ poolStopPing := &openAIWSConnPool{workerStopCh: make(chan struct{})}
+ pingDone := make(chan struct{})
+ go func() {
+ poolStopPing.runBackgroundPingWorker()
+ close(pingDone)
+ }()
+ close(poolStopPing.workerStopCh)
+ select {
+ case <-pingDone:
+ case <-time.After(500 * time.Millisecond):
+ t.Fatal("runBackgroundPingWorker 未在 stop 信号后退出")
+ }
+
+ poolStopCleanup := &openAIWSConnPool{workerStopCh: make(chan struct{})}
+ cleanupDone := make(chan struct{})
+ go func() {
+ poolStopCleanup.runBackgroundCleanupWorker()
+ close(cleanupDone)
+ }()
+ close(poolStopCleanup.workerStopCh)
+ select {
+ case <-cleanupDone:
+ case <-time.After(500 * time.Millisecond):
+ t.Fatal("runBackgroundCleanupWorker 未在 stop 信号后退出")
+ }
+}
+
+func TestOpenAIWSConnPool_SnapshotIdleConnsForPing_SkipsInvalidEntries(t *testing.T) {
+ pool := &openAIWSConnPool{}
+ pool.accounts.Store("invalid-key", &openAIWSAccountPool{})
+ pool.accounts.Store(int64(123), "invalid-value")
+
+ accountID := int64(123)
+ ap := &openAIWSAccountPool{
+ conns: make(map[string]*openAIWSConn),
+ }
+ ap.conns["nil_conn"] = nil
+
+ leased := newOpenAIWSConn("leased", accountID, &openAIWSFakeConn{}, nil)
+ require.True(t, leased.tryAcquire())
+ ap.conns[leased.id] = leased
+
+ waiting := newOpenAIWSConn("waiting", accountID, &openAIWSFakeConn{}, nil)
+ waiting.waiters.Store(1)
+ ap.conns[waiting.id] = waiting
+
+ idle := newOpenAIWSConn("idle", accountID, &openAIWSFakeConn{}, nil)
+ ap.conns[idle.id] = idle
+
+ pool.accounts.Store(accountID, ap)
+ candidates := pool.snapshotIdleConnsForPing()
+ require.Len(t, candidates, 1)
+ require.Equal(t, idle.id, candidates[0].conn.id)
+}
+
+func TestOpenAIWSConnPool_RunBackgroundCleanupSweep_SkipsInvalidAndUsesAccountCap(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
+ cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true
+
+ pool := &openAIWSConnPool{cfg: cfg}
+ pool.accounts.Store("bad-key", "bad-value")
+
+ accountID := int64(2026)
+ ap := &openAIWSAccountPool{
+ conns: make(map[string]*openAIWSConn),
+ }
+ ap.conns["nil_conn"] = nil
+ stale := newOpenAIWSConn("stale_bg_cleanup", accountID, &openAIWSFakeConn{}, nil)
+ stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
+ stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
+ ap.conns[stale.id] = stale
+ ap.lastAcquire = &openAIWSAcquireRequest{
+ Account: &Account{
+ ID: accountID,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ },
+ }
+ pool.accounts.Store(accountID, ap)
+
+ now := time.Now()
+ require.NotPanics(t, func() {
+ pool.runBackgroundCleanupSweep(now)
+ })
+
+ ap.mu.Lock()
+ _, nilConnExists := ap.conns["nil_conn"]
+ _, exists := ap.conns[stale.id]
+ lastCleanupAt := ap.lastCleanupAt
+ ap.mu.Unlock()
+
+ require.False(t, nilConnExists, "后台清理应移除无效 nil 连接条目")
+ require.False(t, exists, "后台清理应清理过期连接")
+ require.Equal(t, now, lastCleanupAt)
+}
+
+func TestOpenAIWSConnPool_QueueLimitPerConn_DefaultAndConfigured(t *testing.T) {
+ var nilPool *openAIWSConnPool
+ require.Equal(t, 256, nilPool.queueLimitPerConn())
+
+ pool := &openAIWSConnPool{cfg: &config.Config{}}
+ require.Equal(t, 256, pool.queueLimitPerConn())
+
+ pool.cfg.Gateway.OpenAIWS.QueueLimitPerConn = 9
+ require.Equal(t, 9, pool.queueLimitPerConn())
+}
+
+func TestOpenAIWSConnPool_Close(t *testing.T) {
+ cfg := &config.Config{}
+ pool := newOpenAIWSConnPool(cfg)
+
+ // Close 应该可以安全调用
+ pool.Close()
+
+ // workerStopCh 应已关闭
+ select {
+ case <-pool.workerStopCh:
+ // 预期:channel 已关闭
+ default:
+ t.Fatal("Close 后 workerStopCh 应已关闭")
+ }
+
+ // 多次调用 Close 不应 panic
+ pool.Close()
+
+ // nil pool 调用 Close 不应 panic
+ var nilPool *openAIWSConnPool
+ nilPool.Close()
+}
+
+func TestOpenAIWSDialError_ErrorAndUnwrap(t *testing.T) {
+ baseErr := errors.New("boom")
+ dialErr := &openAIWSDialError{StatusCode: 502, Err: baseErr}
+ require.Contains(t, dialErr.Error(), "status=502")
+ require.ErrorIs(t, dialErr.Unwrap(), baseErr)
+
+ noStatus := &openAIWSDialError{Err: baseErr}
+ require.Contains(t, noStatus.Error(), "boom")
+
+ var nilDialErr *openAIWSDialError
+ require.Equal(t, "", nilDialErr.Error())
+ require.NoError(t, nilDialErr.Unwrap())
+}
+
+func TestOpenAIWSConnLease_ReadWriteHelpersAndConnStats(t *testing.T) {
+ conn := newOpenAIWSConn("helper_conn", 1, &openAIWSFakeConn{}, http.Header{
+ "X-Test": []string{" value "},
+ })
+ lease := &openAIWSConnLease{conn: conn}
+
+ require.NoError(t, lease.WriteJSONContext(context.Background(), map[string]any{"type": "response.create"}))
+ payload, err := lease.ReadMessage(100 * time.Millisecond)
+ require.NoError(t, err)
+ require.Contains(t, string(payload), "response.completed")
+
+ payload, err = lease.ReadMessageContext(context.Background())
+ require.NoError(t, err)
+ require.Contains(t, string(payload), "response.completed")
+
+ payload, err = conn.readMessageWithTimeout(100 * time.Millisecond)
+ require.NoError(t, err)
+ require.Contains(t, string(payload), "response.completed")
+
+ require.Equal(t, "value", conn.handshakeHeader(" X-Test "))
+ require.NotZero(t, conn.createdAt())
+ require.NotZero(t, conn.lastUsedAt())
+ require.GreaterOrEqual(t, conn.age(time.Now()), time.Duration(0))
+ require.GreaterOrEqual(t, conn.idleDuration(time.Now()), time.Duration(0))
+ require.False(t, conn.isLeased())
+
+ // 覆盖空上下文路径
+ _, err = conn.readMessage(context.Background())
+ require.NoError(t, err)
+
+ // 覆盖 nil 保护分支
+ var nilConn *openAIWSConn
+ require.ErrorIs(t, nilConn.writeJSONWithTimeout(context.Background(), map[string]any{}, time.Second), errOpenAIWSConnClosed)
+ _, err = nilConn.readMessageWithTimeout(10 * time.Millisecond)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+ _, err = nilConn.readMessageWithContextTimeout(context.Background(), 10*time.Millisecond)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+}
+
+func TestOpenAIWSConnPool_PickOldestIdleAndAccountPoolLoad(t *testing.T) {
+ pool := &openAIWSConnPool{}
+ accountID := int64(404)
+ ap := &openAIWSAccountPool{conns: map[string]*openAIWSConn{}}
+
+ idleOld := newOpenAIWSConn("idle_old", accountID, &openAIWSFakeConn{}, nil)
+ idleOld.lastUsedNano.Store(time.Now().Add(-10 * time.Minute).UnixNano())
+ idleNew := newOpenAIWSConn("idle_new", accountID, &openAIWSFakeConn{}, nil)
+ idleNew.lastUsedNano.Store(time.Now().Add(-1 * time.Minute).UnixNano())
+ leased := newOpenAIWSConn("leased", accountID, &openAIWSFakeConn{}, nil)
+ require.True(t, leased.tryAcquire())
+ leased.waiters.Store(2)
+
+ ap.conns[idleOld.id] = idleOld
+ ap.conns[idleNew.id] = idleNew
+ ap.conns[leased.id] = leased
+
+ oldest := pool.pickOldestIdleConnLocked(ap)
+ require.NotNil(t, oldest)
+ require.Equal(t, idleOld.id, oldest.id)
+
+ inflight, waiters := accountPoolLoadLocked(ap)
+ require.Equal(t, 1, inflight)
+ require.Equal(t, 2, waiters)
+
+ pool.accounts.Store(accountID, ap)
+ loadInflight, loadWaiters, conns := pool.AccountPoolLoad(accountID)
+ require.Equal(t, 1, loadInflight)
+ require.Equal(t, 2, loadWaiters)
+ require.Equal(t, 3, conns)
+
+ zeroInflight, zeroWaiters, zeroConns := pool.AccountPoolLoad(0)
+ require.Equal(t, 0, zeroInflight)
+ require.Equal(t, 0, zeroWaiters)
+ require.Equal(t, 0, zeroConns)
+}
+
+func TestOpenAIWSConnPool_Close_WaitsWorkerGroupAndNilStopChannel(t *testing.T) {
+ pool := &openAIWSConnPool{}
+ release := make(chan struct{})
+ pool.workerWg.Add(1)
+ go func() {
+ defer pool.workerWg.Done()
+ <-release
+ }()
+
+ closed := make(chan struct{})
+ go func() {
+ pool.Close()
+ close(closed)
+ }()
+
+ select {
+ case <-closed:
+ t.Fatal("Close 不应在 WaitGroup 未完成时提前返回")
+ case <-time.After(30 * time.Millisecond):
+ }
+
+ close(release)
+ select {
+ case <-closed:
+ case <-time.After(time.Second):
+ t.Fatal("Close 未等待 workerWg 完成")
+ }
+}
+
+func TestOpenAIWSConnPool_Close_ClosesOnlyIdleConnections(t *testing.T) {
+ pool := &openAIWSConnPool{
+ workerStopCh: make(chan struct{}),
+ }
+
+ accountID := int64(606)
+ ap := &openAIWSAccountPool{
+ conns: map[string]*openAIWSConn{},
+ }
+ idle := newOpenAIWSConn("idle_conn", accountID, &openAIWSFakeConn{}, nil)
+ leased := newOpenAIWSConn("leased_conn", accountID, &openAIWSFakeConn{}, nil)
+ require.True(t, leased.tryAcquire())
+
+ ap.conns[idle.id] = idle
+ ap.conns[leased.id] = leased
+ pool.accounts.Store(accountID, ap)
+ pool.accounts.Store("invalid-key", "invalid-value")
+
+ pool.Close()
+
+ select {
+ case <-idle.closedCh:
+ // idle should be closed
+ default:
+ t.Fatal("空闲连接应在 Close 时被关闭")
+ }
+
+ select {
+ case <-leased.closedCh:
+ t.Fatal("已租赁连接不应在 Close 时被关闭")
+ default:
+ }
+
+ leased.release()
+ pool.Close()
+}
+
+func TestOpenAIWSConnPool_RunBackgroundPingSweep_ConcurrencyLimit(t *testing.T) {
+ cfg := &config.Config{}
+ pool := newOpenAIWSConnPool(cfg)
+ accountID := int64(505)
+ ap := pool.getOrCreateAccountPool(accountID)
+
+ var current atomic.Int32
+ var maxConcurrent atomic.Int32
+ release := make(chan struct{})
+ for i := 0; i < 25; i++ {
+ conn := newOpenAIWSConn(pool.nextConnID(accountID), accountID, &openAIWSPingBlockingConn{
+ current: ¤t,
+ maxConcurrent: &maxConcurrent,
+ release: release,
+ }, nil)
+ ap.mu.Lock()
+ ap.conns[conn.id] = conn
+ ap.mu.Unlock()
+ }
+
+ done := make(chan struct{})
+ go func() {
+ pool.runBackgroundPingSweep()
+ close(done)
+ }()
+
+ require.Eventually(t, func() bool {
+ return maxConcurrent.Load() >= 10
+ }, time.Second, 10*time.Millisecond)
+
+ close(release)
+ select {
+ case <-done:
+ case <-time.After(2 * time.Second):
+ t.Fatal("runBackgroundPingSweep 未在释放后完成")
+ }
+
+ require.LessOrEqual(t, maxConcurrent.Load(), int32(10))
+}
+
+func TestOpenAIWSConnLease_BasicGetterBranches(t *testing.T) {
+ var nilLease *openAIWSConnLease
+ require.Equal(t, "", nilLease.ConnID())
+ require.Equal(t, time.Duration(0), nilLease.QueueWaitDuration())
+ require.Equal(t, time.Duration(0), nilLease.ConnPickDuration())
+ require.False(t, nilLease.Reused())
+ require.Equal(t, "", nilLease.HandshakeHeader("x-test"))
+ require.False(t, nilLease.IsPrewarmed())
+ nilLease.MarkPrewarmed()
+ nilLease.Release()
+
+ conn := newOpenAIWSConn("getter_conn", 1, &openAIWSFakeConn{}, http.Header{"X-Test": []string{"ok"}})
+ lease := &openAIWSConnLease{
+ conn: conn,
+ queueWait: 3 * time.Millisecond,
+ connPick: 4 * time.Millisecond,
+ reused: true,
+ }
+ require.Equal(t, "getter_conn", lease.ConnID())
+ require.Equal(t, 3*time.Millisecond, lease.QueueWaitDuration())
+ require.Equal(t, 4*time.Millisecond, lease.ConnPickDuration())
+ require.True(t, lease.Reused())
+ require.Equal(t, "ok", lease.HandshakeHeader("x-test"))
+ require.False(t, lease.IsPrewarmed())
+ lease.MarkPrewarmed()
+ require.True(t, lease.IsPrewarmed())
+ lease.Release()
+}
+
+func TestOpenAIWSConnPool_UtilityBranches(t *testing.T) {
+ var nilPool *openAIWSConnPool
+ require.Equal(t, OpenAIWSPoolMetricsSnapshot{}, nilPool.SnapshotMetrics())
+ require.Equal(t, OpenAIWSTransportMetricsSnapshot{}, nilPool.SnapshotTransportMetrics())
+
+ pool := &openAIWSConnPool{cfg: &config.Config{}}
+ pool.metrics.acquireTotal.Store(7)
+ pool.metrics.acquireReuseTotal.Store(3)
+ metrics := pool.SnapshotMetrics()
+ require.Equal(t, int64(7), metrics.AcquireTotal)
+ require.Equal(t, int64(3), metrics.AcquireReuseTotal)
+
+ // 非 transport metrics dialer 路径
+ pool.clientDialer = &openAIWSFakeDialer{}
+ require.Equal(t, OpenAIWSTransportMetricsSnapshot{}, pool.SnapshotTransportMetrics())
+ pool.setClientDialerForTest(nil)
+ require.NotNil(t, pool.clientDialer)
+
+ require.Equal(t, 8, nilPool.maxConnsHardCap())
+ require.False(t, nilPool.dynamicMaxConnsEnabled())
+ require.Equal(t, 1.0, nilPool.maxConnsFactorByAccount(nil))
+ require.Equal(t, 0, nilPool.minIdlePerAccount())
+ require.Equal(t, 4, nilPool.maxIdlePerAccount())
+ require.Equal(t, 256, nilPool.queueLimitPerConn())
+ require.Equal(t, 0.7, nilPool.targetUtilization())
+ require.Equal(t, time.Duration(0), nilPool.prewarmCooldown())
+ require.Equal(t, 10*time.Second, nilPool.dialTimeout())
+
+ // shouldSuppressPrewarmLocked 覆盖 3 条分支
+ now := time.Now()
+ apNilFail := &openAIWSAccountPool{prewarmFails: 1}
+ require.False(t, pool.shouldSuppressPrewarmLocked(apNilFail, now))
+ apZeroTime := &openAIWSAccountPool{prewarmFails: 2}
+ require.False(t, pool.shouldSuppressPrewarmLocked(apZeroTime, now))
+ require.Equal(t, 0, apZeroTime.prewarmFails)
+ apOldFail := &openAIWSAccountPool{prewarmFails: 2, prewarmFailAt: now.Add(-openAIWSPrewarmFailureWindow - time.Second)}
+ require.False(t, pool.shouldSuppressPrewarmLocked(apOldFail, now))
+ apRecentFail := &openAIWSAccountPool{prewarmFails: openAIWSPrewarmFailureSuppress, prewarmFailAt: now}
+ require.True(t, pool.shouldSuppressPrewarmLocked(apRecentFail, now))
+
+ // recordConnPickDuration 的保护分支
+ nilPool.recordConnPickDuration(10 * time.Millisecond)
+ pool.recordConnPickDuration(-10 * time.Millisecond)
+ require.Equal(t, int64(1), pool.metrics.connPickTotal.Load())
+
+ // account pool 读写分支
+ require.Nil(t, nilPool.getOrCreateAccountPool(1))
+ require.Nil(t, pool.getOrCreateAccountPool(0))
+ pool.accounts.Store(int64(7), "invalid")
+ ap := pool.getOrCreateAccountPool(7)
+ require.NotNil(t, ap)
+ _, ok := pool.getAccountPool(0)
+ require.False(t, ok)
+ _, ok = pool.getAccountPool(12345)
+ require.False(t, ok)
+ pool.accounts.Store(int64(8), "bad-type")
+ _, ok = pool.getAccountPool(8)
+ require.False(t, ok)
+
+ // health check 条件
+ require.False(t, pool.shouldHealthCheckConn(nil))
+ conn := newOpenAIWSConn("health", 1, &openAIWSFakeConn{}, nil)
+ conn.lastUsedNano.Store(time.Now().Add(-openAIWSConnHealthCheckIdle - time.Second).UnixNano())
+ require.True(t, pool.shouldHealthCheckConn(conn))
+}
+
+func TestOpenAIWSConn_LeaseAndTimeHelpers_NilAndClosedBranches(t *testing.T) {
+ var nilConn *openAIWSConn
+ nilConn.touch()
+ require.Equal(t, time.Time{}, nilConn.createdAt())
+ require.Equal(t, time.Time{}, nilConn.lastUsedAt())
+ require.Equal(t, time.Duration(0), nilConn.idleDuration(time.Now()))
+ require.Equal(t, time.Duration(0), nilConn.age(time.Now()))
+ require.False(t, nilConn.isLeased())
+ require.False(t, nilConn.isPrewarmed())
+ nilConn.markPrewarmed()
+
+ conn := newOpenAIWSConn("lease_state", 1, &openAIWSFakeConn{}, nil)
+ require.True(t, conn.tryAcquire())
+ require.True(t, conn.isLeased())
+ conn.release()
+ require.False(t, conn.isLeased())
+ conn.close()
+ require.False(t, conn.tryAcquire())
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ err := conn.acquire(ctx)
+ require.Error(t, err)
+}
+
+func TestOpenAIWSConnLease_ReadWriteNilConnBranches(t *testing.T) {
+ lease := &openAIWSConnLease{}
+ require.ErrorIs(t, lease.WriteJSON(map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed)
+ require.ErrorIs(t, lease.WriteJSONContext(context.Background(), map[string]any{"k": "v"}), errOpenAIWSConnClosed)
+ _, err := lease.ReadMessage(10 * time.Millisecond)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+ _, err = lease.ReadMessageContext(context.Background())
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+ _, err = lease.ReadMessageWithContextTimeout(context.Background(), 10*time.Millisecond)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+}
+
+func TestOpenAIWSConnLease_ReleasedLeaseGuards(t *testing.T) {
+ conn := newOpenAIWSConn("released_guard", 1, &openAIWSFakeConn{}, nil)
+ lease := &openAIWSConnLease{conn: conn}
+
+ require.NoError(t, lease.PingWithTimeout(50*time.Millisecond))
+
+ lease.Release()
+ lease.Release() // idempotent
+
+ require.ErrorIs(t, lease.WriteJSON(map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed)
+ require.ErrorIs(t, lease.WriteJSONContext(context.Background(), map[string]any{"k": "v"}), errOpenAIWSConnClosed)
+ require.ErrorIs(t, lease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed)
+
+ _, err := lease.ReadMessage(10 * time.Millisecond)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+ _, err = lease.ReadMessageContext(context.Background())
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+ _, err = lease.ReadMessageWithContextTimeout(context.Background(), 10*time.Millisecond)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+
+ require.ErrorIs(t, lease.PingWithTimeout(50*time.Millisecond), errOpenAIWSConnClosed)
+}
+
+func TestOpenAIWSConnLease_MarkBrokenAfterRelease_NoEviction(t *testing.T) {
+ conn := newOpenAIWSConn("released_markbroken", 7, &openAIWSFakeConn{}, nil)
+ ap := &openAIWSAccountPool{
+ conns: map[string]*openAIWSConn{
+ conn.id: conn,
+ },
+ }
+ pool := &openAIWSConnPool{}
+ pool.accounts.Store(int64(7), ap)
+
+ lease := &openAIWSConnLease{
+ pool: pool,
+ accountID: 7,
+ conn: conn,
+ }
+
+ lease.Release()
+ lease.MarkBroken()
+
+ ap.mu.Lock()
+ _, exists := ap.conns[conn.id]
+ ap.mu.Unlock()
+ require.True(t, exists, "released lease should not evict active pool connection")
+}
+
+func TestOpenAIWSConn_AdditionalGuardBranches(t *testing.T) {
+ var nilConn *openAIWSConn
+ require.False(t, nilConn.tryAcquire())
+ require.ErrorIs(t, nilConn.acquire(context.Background()), errOpenAIWSConnClosed)
+ nilConn.release()
+ nilConn.close()
+ require.Equal(t, "", nilConn.handshakeHeader("x-test"))
+
+ connBusy := newOpenAIWSConn("busy_ctx", 1, &openAIWSFakeConn{}, nil)
+ require.True(t, connBusy.tryAcquire())
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ require.ErrorIs(t, connBusy.acquire(ctx), context.Canceled)
+ connBusy.release()
+
+ connClosed := newOpenAIWSConn("closed_guard", 1, &openAIWSFakeConn{}, nil)
+ connClosed.close()
+ require.ErrorIs(
+ t,
+ connClosed.writeJSONWithTimeout(context.Background(), map[string]any{"k": "v"}, time.Second),
+ errOpenAIWSConnClosed,
+ )
+ _, err := connClosed.readMessageWithContextTimeout(context.Background(), time.Second)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+ require.ErrorIs(t, connClosed.pingWithTimeout(time.Second), errOpenAIWSConnClosed)
+
+ connNoWS := newOpenAIWSConn("no_ws", 1, nil, nil)
+ require.ErrorIs(t, connNoWS.writeJSON(map[string]any{"k": "v"}, context.Background()), errOpenAIWSConnClosed)
+ _, err = connNoWS.readMessage(context.Background())
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+ require.ErrorIs(t, connNoWS.pingWithTimeout(time.Second), errOpenAIWSConnClosed)
+ require.Equal(t, "", connNoWS.handshakeHeader("x-test"))
+
+ connOK := newOpenAIWSConn("ok", 1, &openAIWSFakeConn{}, nil)
+ require.NoError(t, connOK.writeJSON(map[string]any{"k": "v"}, nil))
+ _, err = connOK.readMessageWithContextTimeout(context.Background(), 0)
+ require.NoError(t, err)
+ require.NoError(t, connOK.pingWithTimeout(0))
+
+ connZero := newOpenAIWSConn("zero_ts", 1, &openAIWSFakeConn{}, nil)
+ connZero.createdAtNano.Store(0)
+ connZero.lastUsedNano.Store(0)
+ require.True(t, connZero.createdAt().IsZero())
+ require.True(t, connZero.lastUsedAt().IsZero())
+ require.Equal(t, time.Duration(0), connZero.idleDuration(time.Now()))
+ require.Equal(t, time.Duration(0), connZero.age(time.Now()))
+
+ require.Nil(t, cloneOpenAIWSAcquireRequestPtr(nil))
+ copied := cloneHeader(http.Header{
+ "X-Empty": []string{},
+ "X-Test": []string{"v1"},
+ })
+ require.Contains(t, copied, "X-Empty")
+ require.Nil(t, copied["X-Empty"])
+ require.Equal(t, "v1", copied.Get("X-Test"))
+
+ closeOpenAIWSConns([]*openAIWSConn{nil, connOK})
+}
+
+func TestOpenAIWSConnLease_MarkBrokenEvictsConn(t *testing.T) {
+ pool := newOpenAIWSConnPool(&config.Config{})
+ accountID := int64(5001)
+ conn := newOpenAIWSConn("broken_me", accountID, &openAIWSFakeConn{}, nil)
+ ap := pool.getOrCreateAccountPool(accountID)
+ ap.mu.Lock()
+ ap.conns[conn.id] = conn
+ ap.mu.Unlock()
+
+ lease := &openAIWSConnLease{
+ pool: pool,
+ accountID: accountID,
+ conn: conn,
+ }
+ lease.MarkBroken()
+
+ ap.mu.Lock()
+ _, exists := ap.conns[conn.id]
+ ap.mu.Unlock()
+ require.False(t, exists)
+ require.False(t, conn.tryAcquire(), "被标记为 broken 的连接应被关闭")
+}
+
+func TestOpenAIWSConnPool_TargetConnCountAndPrewarmBranches(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ pool := newOpenAIWSConnPool(cfg)
+
+ require.Equal(t, 0, pool.targetConnCountLocked(nil, 1))
+ ap := &openAIWSAccountPool{conns: map[string]*openAIWSConn{}}
+ require.Equal(t, 0, pool.targetConnCountLocked(ap, 0))
+
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 3
+ require.Equal(t, 1, pool.targetConnCountLocked(ap, 1), "minIdle 应被 maxConns 截断")
+
+ // 覆盖 waiters>0 且 target 需要至少 len(conns)+1 的分支
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.9
+ busy := newOpenAIWSConn("busy_target", 2, &openAIWSFakeConn{}, nil)
+ require.True(t, busy.tryAcquire())
+ busy.waiters.Store(1)
+ ap.conns[busy.id] = busy
+ target := pool.targetConnCountLocked(ap, 4)
+ require.GreaterOrEqual(t, target, len(ap.conns)+1)
+
+ // prewarm: account pool 缺失时,拨号后的连接应被关闭并提前返回
+ req := openAIWSAcquireRequest{
+ Account: &Account{ID: 999, Platform: PlatformOpenAI, Type: AccountTypeAPIKey},
+ WSURL: "wss://example.com/v1/responses",
+ }
+ pool.prewarmConns(999, req, 1)
+
+ // prewarm: 拨号失败分支(prewarmFails 累加)
+ accountID := int64(1000)
+ failPool := newOpenAIWSConnPool(cfg)
+ failPool.setClientDialerForTest(&openAIWSAlwaysFailDialer{})
+ apFail := failPool.getOrCreateAccountPool(accountID)
+ apFail.mu.Lock()
+ apFail.creating = 1
+ apFail.mu.Unlock()
+ req.Account.ID = accountID
+ failPool.prewarmConns(accountID, req, 1)
+ apFail.mu.Lock()
+ require.GreaterOrEqual(t, apFail.prewarmFails, 1)
+ apFail.mu.Unlock()
+}
+
+func TestOpenAIWSConnPool_Acquire_ErrorBranches(t *testing.T) {
+ var nilPool *openAIWSConnPool
+ _, err := nilPool.Acquire(context.Background(), openAIWSAcquireRequest{})
+ require.Error(t, err)
+
+ pool := newOpenAIWSConnPool(&config.Config{})
+ _, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{
+ Account: &Account{ID: 1},
+ WSURL: " ",
+ })
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "ws url is empty")
+
+ // target=nil 分支:池满且仅有 nil 连接
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 1
+ fullPool := newOpenAIWSConnPool(cfg)
+ account := &Account{ID: 2001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+ ap := fullPool.getOrCreateAccountPool(account.ID)
+ ap.mu.Lock()
+ ap.conns["nil"] = nil
+ ap.lastCleanupAt = time.Now()
+ ap.mu.Unlock()
+ _, err = fullPool.Acquire(context.Background(), openAIWSAcquireRequest{
+ Account: account,
+ WSURL: "wss://example.com/v1/responses",
+ })
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+
+ // queue full 分支:waiters 达上限
+ account2 := &Account{ID: 2002, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+ ap2 := fullPool.getOrCreateAccountPool(account2.ID)
+ conn := newOpenAIWSConn("queue_full", account2.ID, &openAIWSFakeConn{}, nil)
+ require.True(t, conn.tryAcquire())
+ conn.waiters.Store(1)
+ ap2.mu.Lock()
+ ap2.conns[conn.id] = conn
+ ap2.lastCleanupAt = time.Now()
+ ap2.mu.Unlock()
+ _, err = fullPool.Acquire(context.Background(), openAIWSAcquireRequest{
+ Account: account2,
+ WSURL: "wss://example.com/v1/responses",
+ })
+ require.ErrorIs(t, err, errOpenAIWSConnQueueFull)
+}
+
+type openAIWSFakeDialer struct{}
+
+func (d *openAIWSFakeDialer) Dial(
+ ctx context.Context,
+ wsURL string,
+ headers http.Header,
+ proxyURL string,
+) (openAIWSClientConn, int, http.Header, error) {
+ _ = ctx
+ _ = wsURL
+ _ = headers
+ _ = proxyURL
+ return &openAIWSFakeConn{}, 0, nil, nil
+}
+
+type openAIWSCountingDialer struct {
+ mu sync.Mutex
+ dialCount int
+}
+
+type openAIWSAlwaysFailDialer struct {
+ mu sync.Mutex
+ dialCount int
+}
+
+type openAIWSPingBlockingConn struct {
+ current *atomic.Int32
+ maxConcurrent *atomic.Int32
+ release <-chan struct{}
+}
+
+func (c *openAIWSPingBlockingConn) WriteJSON(context.Context, any) error {
+ return nil
+}
+
+func (c *openAIWSPingBlockingConn) ReadMessage(context.Context) ([]byte, error) {
+ return []byte(`{"type":"response.completed","response":{"id":"resp_blocking_ping"}}`), nil
+}
+
+func (c *openAIWSPingBlockingConn) Ping(ctx context.Context) error {
+ if c.current == nil || c.maxConcurrent == nil {
+ return nil
+ }
+
+ now := c.current.Add(1)
+ for {
+ prev := c.maxConcurrent.Load()
+ if now <= prev || c.maxConcurrent.CompareAndSwap(prev, now) {
+ break
+ }
+ }
+ defer c.current.Add(-1)
+
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-c.release:
+ return nil
+ }
+}
+
+func (c *openAIWSPingBlockingConn) Close() error {
+ return nil
+}
+
+func (d *openAIWSCountingDialer) Dial(
+ ctx context.Context,
+ wsURL string,
+ headers http.Header,
+ proxyURL string,
+) (openAIWSClientConn, int, http.Header, error) {
+ _ = ctx
+ _ = wsURL
+ _ = headers
+ _ = proxyURL
+ d.mu.Lock()
+ d.dialCount++
+ d.mu.Unlock()
+ return &openAIWSFakeConn{}, 0, nil, nil
+}
+
+func (d *openAIWSCountingDialer) DialCount() int {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return d.dialCount
+}
+
+func (d *openAIWSAlwaysFailDialer) Dial(
+ ctx context.Context,
+ wsURL string,
+ headers http.Header,
+ proxyURL string,
+) (openAIWSClientConn, int, http.Header, error) {
+ _ = ctx
+ _ = wsURL
+ _ = headers
+ _ = proxyURL
+ d.mu.Lock()
+ d.dialCount++
+ d.mu.Unlock()
+ return nil, 503, nil, errors.New("dial failed")
+}
+
+func (d *openAIWSAlwaysFailDialer) DialCount() int {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return d.dialCount
+}
+
+type openAIWSFakeConn struct {
+ mu sync.Mutex
+ closed bool
+ payload [][]byte
+}
+
+func (c *openAIWSFakeConn) WriteJSON(ctx context.Context, value any) error {
+ _ = ctx
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closed {
+ return errors.New("closed")
+ }
+ c.payload = append(c.payload, []byte("ok"))
+ _ = value
+ return nil
+}
+
+func (c *openAIWSFakeConn) ReadMessage(ctx context.Context) ([]byte, error) {
+ _ = ctx
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closed {
+ return nil, errors.New("closed")
+ }
+ return []byte(`{"type":"response.completed","response":{"id":"resp_fake"}}`), nil
+}
+
+func (c *openAIWSFakeConn) Ping(ctx context.Context) error {
+ _ = ctx
+ return nil
+}
+
+func (c *openAIWSFakeConn) Close() error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.closed = true
+ return nil
+}
+
+type openAIWSBlockingConn struct {
+ readDelay time.Duration
+}
+
+func (c *openAIWSBlockingConn) WriteJSON(ctx context.Context, value any) error {
+ _ = ctx
+ _ = value
+ return nil
+}
+
+func (c *openAIWSBlockingConn) ReadMessage(ctx context.Context) ([]byte, error) {
+ delay := c.readDelay
+ if delay <= 0 {
+ delay = 10 * time.Millisecond
+ }
+ timer := time.NewTimer(delay)
+ defer timer.Stop()
+
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-timer.C:
+ return []byte(`{"type":"response.completed","response":{"id":"resp_blocking"}}`), nil
+ }
+}
+
+func (c *openAIWSBlockingConn) Ping(ctx context.Context) error {
+ _ = ctx
+ return nil
+}
+
+func (c *openAIWSBlockingConn) Close() error {
+ return nil
+}
+
+type openAIWSWriteBlockingConn struct{}
+
+func (c *openAIWSWriteBlockingConn) WriteJSON(ctx context.Context, _ any) error {
+ <-ctx.Done()
+ return ctx.Err()
+}
+
+func (c *openAIWSWriteBlockingConn) ReadMessage(context.Context) ([]byte, error) {
+ return []byte(`{"type":"response.completed","response":{"id":"resp_write_block"}}`), nil
+}
+
+func (c *openAIWSWriteBlockingConn) Ping(context.Context) error {
+ return nil
+}
+
+func (c *openAIWSWriteBlockingConn) Close() error {
+ return nil
+}
+
+type openAIWSPingFailConn struct{}
+
+func (c *openAIWSPingFailConn) WriteJSON(context.Context, any) error {
+ return nil
+}
+
+func (c *openAIWSPingFailConn) ReadMessage(context.Context) ([]byte, error) {
+ return []byte(`{"type":"response.completed","response":{"id":"resp_ping_fail"}}`), nil
+}
+
+func (c *openAIWSPingFailConn) Ping(context.Context) error {
+ return errors.New("ping failed")
+}
+
+func (c *openAIWSPingFailConn) Close() error {
+ return nil
+}
+
+type openAIWSContextProbeConn struct {
+ lastWriteCtx context.Context
+}
+
+func (c *openAIWSContextProbeConn) WriteJSON(ctx context.Context, _ any) error {
+ c.lastWriteCtx = ctx
+ return nil
+}
+
+func (c *openAIWSContextProbeConn) ReadMessage(context.Context) ([]byte, error) {
+ return []byte(`{"type":"response.completed","response":{"id":"resp_ctx_probe"}}`), nil
+}
+
+func (c *openAIWSContextProbeConn) Ping(context.Context) error {
+ return nil
+}
+
+func (c *openAIWSContextProbeConn) Close() error {
+ return nil
+}
+
+type openAIWSNilConnDialer struct{}
+
+func (d *openAIWSNilConnDialer) Dial(
+ ctx context.Context,
+ wsURL string,
+ headers http.Header,
+ proxyURL string,
+) (openAIWSClientConn, int, http.Header, error) {
+ _ = ctx
+ _ = wsURL
+ _ = headers
+ _ = proxyURL
+ return nil, 200, nil, nil
+}
+
+func TestOpenAIWSConnPool_DialConnNilConnection(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
+
+ pool := newOpenAIWSConnPool(cfg)
+ pool.setClientDialerForTest(&openAIWSNilConnDialer{})
+ account := &Account{ID: 91, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ _, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
+ Account: account,
+ WSURL: "wss://example.com/v1/responses",
+ })
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "nil connection")
+}
+
+func TestOpenAIWSConnPool_SnapshotTransportMetrics(t *testing.T) {
+ cfg := &config.Config{}
+ pool := newOpenAIWSConnPool(cfg)
+
+ dialer, ok := pool.clientDialer.(*coderOpenAIWSClientDialer)
+ require.True(t, ok)
+
+ _, err := dialer.proxyHTTPClient("http://127.0.0.1:28080")
+ require.NoError(t, err)
+ _, err = dialer.proxyHTTPClient("http://127.0.0.1:28080")
+ require.NoError(t, err)
+ _, err = dialer.proxyHTTPClient("http://127.0.0.1:28081")
+ require.NoError(t, err)
+
+ snapshot := pool.SnapshotTransportMetrics()
+ require.Equal(t, int64(1), snapshot.ProxyClientCacheHits)
+ require.Equal(t, int64(2), snapshot.ProxyClientCacheMisses)
+ require.InDelta(t, 1.0/3.0, snapshot.TransportReuseRatio, 0.0001)
+}
diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go
new file mode 100644
index 00000000..df4d4871
--- /dev/null
+++ b/backend/internal/service/openai_ws_protocol_forward_test.go
@@ -0,0 +1,1218 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/gin-gonic/gin"
+ "github.com/gorilla/websocket"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+func TestOpenAIGatewayService_Forward_PreservePreviousResponseIDWhenWSEnabled(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ http.NotFound(w, r)
+ }))
+ defer wsFallbackServer.Close()
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "custom-client/1.0")
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(
+ `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`,
+ )),
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: upstream,
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ }
+
+ account := &Account{
+ ID: 1,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": wsFallbackServer.URL,
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123","input":[{"type":"input_text","text":"hello"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.Error(t, err)
+ require.Nil(t, result)
+ require.Nil(t, upstream.lastReq, "WS 模式下失败时不应回退 HTTP")
+}
+
+func TestOpenAIGatewayService_Forward_HTTPIngressStaysHTTPWhenWSEnabled(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ http.NotFound(w, r)
+ }))
+ defer wsFallbackServer.Close()
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportHTTP)
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(
+ `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`,
+ )),
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: upstream,
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ }
+
+ account := &Account{
+ ID: 101,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": wsFallbackServer.URL,
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_http_keep","input":[{"type":"input_text","text":"hello"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.False(t, result.OpenAIWSMode, "HTTP 入站应保持 HTTP 转发")
+ require.NotNil(t, upstream.lastReq, "HTTP 入站应命中 HTTP 上游")
+ require.False(t, gjson.GetBytes(upstream.lastBody, "previous_response_id").Exists(), "HTTP 路径应沿用原逻辑移除 previous_response_id")
+
+ decision, _ := c.Get("openai_ws_transport_decision")
+ reason, _ := c.Get("openai_ws_transport_reason")
+ require.Equal(t, string(OpenAIUpstreamTransportHTTPSSE), decision)
+ require.Equal(t, "client_protocol_http", reason)
+}
+
+func TestOpenAIGatewayService_Forward_RemovePreviousResponseIDWhenWSDisabled(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ http.NotFound(w, r)
+ }))
+ defer wsFallbackServer.Close()
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "custom-client/1.0")
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(
+ `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`,
+ )),
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = false
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: upstream,
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ }
+
+ account := &Account{
+ ID: 1,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": wsFallbackServer.URL,
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123","input":[{"type":"input_text","text":"hello"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.False(t, gjson.GetBytes(upstream.lastBody, "previous_response_id").Exists())
+}
+
+func TestOpenAIGatewayService_Forward_WSv2Dial426FallbackHTTP(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ws426Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusUpgradeRequired)
+ _, _ = w.Write([]byte(`upgrade required`))
+ }))
+ defer ws426Server.Close()
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "custom-client/1.0")
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(
+ `{"usage":{"input_tokens":8,"output_tokens":9,"input_tokens_details":{"cached_tokens":1}}}`,
+ )),
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: upstream,
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ }
+
+ account := &Account{
+ ID: 12,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": ws426Server.URL,
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_426","input":[{"type":"input_text","text":"hello"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.Error(t, err)
+ require.Nil(t, result)
+ require.Contains(t, err.Error(), "upgrade_required")
+ require.Nil(t, upstream.lastReq, "WS 模式下不应再回退 HTTP")
+ require.Equal(t, http.StatusUpgradeRequired, rec.Code)
+ require.Contains(t, rec.Body.String(), "426")
+}
+
+func TestOpenAIGatewayService_Forward_WSv2FallbackCoolingSkipWS(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ http.NotFound(w, r)
+ }))
+ defer wsServer.Close()
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "custom-client/1.0")
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(
+ `{"usage":{"input_tokens":2,"output_tokens":3,"input_tokens_details":{"cached_tokens":0}}}`,
+ )),
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 30
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: upstream,
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ }
+
+ account := &Account{
+ ID: 21,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": wsServer.URL,
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ svc.markOpenAIWSFallbackCooling(account.ID, "upgrade_required")
+ body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_cooling","input":[{"type":"input_text","text":"hello"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.Error(t, err)
+ require.Nil(t, result)
+ require.Nil(t, upstream.lastReq, "WS 模式下不应再回退 HTTP")
+
+ _, ok := c.Get("openai_ws_fallback_cooling")
+ require.False(t, ok, "已移除 fallback cooling 快捷回退路径")
+}
+
+func TestOpenAIGatewayService_Forward_ReturnErrorWhenOnlyWSv1Enabled(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "custom-client/1.0")
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(
+ `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`,
+ )),
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsockets = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: upstream,
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ }
+
+ account := &Account{
+ ID: 31,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": "https://api.openai.com/v1/responses",
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_v1","input":[{"type":"input_text","text":"hello"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.Error(t, err)
+ require.Nil(t, result)
+ require.Contains(t, err.Error(), "ws v1")
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.Contains(t, rec.Body.String(), "WSv1")
+ require.Nil(t, upstream.lastReq, "WSv1 不支持时不应触发 HTTP 上游请求")
+}
+
+func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
+ cfg := &config.Config{}
+ svc := NewOpenAIGatewayService(
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ cfg,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ )
+
+ decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
+ require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
+ require.Equal(t, "account_missing", decision.Reason)
+}
+
+func TestOpenAIGatewayService_Forward_WSv2FallbackWhenResponseAlreadyWrittenReturnsWSError(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ws426Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusUpgradeRequired)
+ _, _ = w.Write([]byte(`upgrade required`))
+ }))
+ defer ws426Server.Close()
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ c.String(http.StatusAccepted, "already-written")
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)),
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: upstream,
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ }
+
+ account := &Account{
+ ID: 41,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": ws426Server.URL,
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.Error(t, err)
+ require.Nil(t, result)
+ require.Contains(t, err.Error(), "ws fallback")
+ require.Nil(t, upstream.lastReq, "已写下游响应时,不应再回退 HTTP")
+}
+
+func TestOpenAIGatewayService_Forward_WSv2StreamEarlyCloseFallbackHTTP(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ t.Errorf("upgrade websocket failed: %v", err)
+ return
+ }
+ defer func() {
+ _ = conn.Close()
+ }()
+
+ var req map[string]any
+ if err := conn.ReadJSON(&req); err != nil {
+ t.Errorf("read ws request failed: %v", err)
+ return
+ }
+
+ // 仅发送 response.created(非 token 事件)后立即关闭,
+ // 模拟线上“上游早期内部错误断连”的场景。
+ if err := conn.WriteJSON(map[string]any{
+ "type": "response.created",
+ "response": map[string]any{
+ "id": "resp_ws_created_only",
+ "model": "gpt-5.3-codex",
+ },
+ }); err != nil {
+ t.Errorf("write response.created failed: %v", err)
+ return
+ }
+ closePayload := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "")
+ _ = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second))
+ }))
+ defer wsServer.Close()
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "custom-client/1.0")
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}},
+ Body: io.NopCloser(strings.NewReader(
+ "data: {\"type\":\"response.output_text.delta\",\"delta\":\"ok\"}\n\n" +
+ "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_http_fallback\",\"usage\":{\"input_tokens\":2,\"output_tokens\":1}}}\n\n" +
+ "data: [DONE]\n\n",
+ )),
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: upstream,
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ }
+
+ account := &Account{
+ ID: 88,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": wsServer.URL,
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.3-codex","stream":true,"input":[{"type":"input_text","text":"hello"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.Error(t, err)
+ require.Nil(t, result)
+ require.Nil(t, upstream.lastReq, "WS 早期断连后不应再回退 HTTP")
+ require.Empty(t, rec.Body.String(), "未产出 token 前上游断连时不应写入下游半截流")
+}
+
+func TestOpenAIGatewayService_Forward_WSv2RetryFiveTimesThenFallbackHTTP(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ var wsAttempts atomic.Int32
+ upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ wsAttempts.Add(1)
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ t.Errorf("upgrade websocket failed: %v", err)
+ return
+ }
+ defer func() {
+ _ = conn.Close()
+ }()
+
+ var req map[string]any
+ if err := conn.ReadJSON(&req); err != nil {
+ t.Errorf("read ws request failed: %v", err)
+ return
+ }
+ closePayload := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "")
+ _ = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second))
+ }))
+ defer wsServer.Close()
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "custom-client/1.0")
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}},
+ Body: io.NopCloser(strings.NewReader(
+ "data: {\"type\":\"response.output_text.delta\",\"delta\":\"ok\"}\n\n" +
+ "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_retry_http_fallback\",\"usage\":{\"input_tokens\":2,\"output_tokens\":1}}}\n\n" +
+ "data: [DONE]\n\n",
+ )),
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: upstream,
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ }
+
+ account := &Account{
+ ID: 89,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": wsServer.URL,
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.3-codex","stream":true,"input":[{"type":"input_text","text":"hello"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.Error(t, err)
+ require.Nil(t, result)
+ require.Nil(t, upstream.lastReq, "WS 重连耗尽后不应再回退 HTTP")
+ require.Equal(t, int32(openAIWSReconnectRetryLimit+1), wsAttempts.Load())
+}
+
+func TestOpenAIGatewayService_Forward_WSv2PolicyViolationFastFallbackHTTP(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ var wsAttempts atomic.Int32
+ upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ wsAttempts.Add(1)
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ t.Errorf("upgrade websocket failed: %v", err)
+ return
+ }
+ defer func() {
+ _ = conn.Close()
+ }()
+
+ var req map[string]any
+ if err := conn.ReadJSON(&req); err != nil {
+ t.Errorf("read ws request failed: %v", err)
+ return
+ }
+ closePayload := websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "")
+ _ = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second))
+ }))
+ defer wsServer.Close()
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "custom-client/1.0")
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`{"id":"resp_policy_fallback","usage":{"input_tokens":1,"output_tokens":1}}`)),
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
+ cfg.Gateway.OpenAIWS.RetryBackoffInitialMS = 1
+ cfg.Gateway.OpenAIWS.RetryBackoffMaxMS = 2
+ cfg.Gateway.OpenAIWS.RetryJitterRatio = 0
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: upstream,
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ }
+
+ account := &Account{
+ ID: 8901,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": wsServer.URL,
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.Error(t, err)
+ require.Nil(t, result)
+ require.Nil(t, upstream.lastReq, "策略违规关闭后不应回退 HTTP")
+ require.Equal(t, int32(1), wsAttempts.Load(), "策略违规不应进行 WS 重试")
+}
+
+func TestOpenAIGatewayService_Forward_WSv2ConnectionLimitReachedRetryThenFallbackHTTP(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ var wsAttempts atomic.Int32
+ upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ wsAttempts.Add(1)
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ t.Errorf("upgrade websocket failed: %v", err)
+ return
+ }
+ defer func() {
+ _ = conn.Close()
+ }()
+
+ var req map[string]any
+ if err := conn.ReadJSON(&req); err != nil {
+ t.Errorf("read ws request failed: %v", err)
+ return
+ }
+ _ = conn.WriteJSON(map[string]any{
+ "type": "error",
+ "error": map[string]any{
+ "code": "websocket_connection_limit_reached",
+ "type": "server_error",
+ "message": "websocket connection limit reached",
+ },
+ })
+ }))
+ defer wsServer.Close()
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "custom-client/1.0")
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_retry_limit","usage":{"input_tokens":1,"output_tokens":1}}`)),
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: upstream,
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ }
+
+ account := &Account{
+ ID: 90,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": wsServer.URL,
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.Error(t, err)
+ require.Nil(t, result)
+ require.Nil(t, upstream.lastReq, "触发 websocket_connection_limit_reached 后不应回退 HTTP")
+ require.Equal(t, int32(openAIWSReconnectRetryLimit+1), wsAttempts.Load())
+}
+
+func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundRecoversByDroppingPreviousResponseID(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ var wsAttempts atomic.Int32
+ var wsRequestPayloads [][]byte
+ var wsRequestMu sync.Mutex
+ upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ attempt := wsAttempts.Add(1)
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ t.Errorf("upgrade websocket failed: %v", err)
+ return
+ }
+ defer func() {
+ _ = conn.Close()
+ }()
+
+ var req map[string]any
+ if err := conn.ReadJSON(&req); err != nil {
+ t.Errorf("read ws request failed: %v", err)
+ return
+ }
+ reqRaw, _ := json.Marshal(req)
+ wsRequestMu.Lock()
+ wsRequestPayloads = append(wsRequestPayloads, reqRaw)
+ wsRequestMu.Unlock()
+ if attempt == 1 {
+ _ = conn.WriteJSON(map[string]any{
+ "type": "error",
+ "error": map[string]any{
+ "code": "previous_response_not_found",
+ "type": "invalid_request_error",
+ "message": "previous response not found",
+ },
+ })
+ return
+ }
+ _ = conn.WriteJSON(map[string]any{
+ "type": "response.completed",
+ "response": map[string]any{
+ "id": "resp_ws_prev_recover_ok",
+ "model": "gpt-5.3-codex",
+ "usage": map[string]any{
+ "input_tokens": 1,
+ "output_tokens": 1,
+ "input_tokens_details": map[string]any{
+ "cached_tokens": 0,
+ },
+ },
+ },
+ })
+ }))
+ defer wsServer.Close()
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "custom-client/1.0")
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)),
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: upstream,
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ }
+
+ account := &Account{
+ ID: 91,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": wsServer.URL,
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"input_text","text":"hello"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "resp_ws_prev_recover_ok", result.RequestID)
+ require.Nil(t, upstream.lastReq, "previous_response_not_found 不应回退 HTTP")
+ require.Equal(t, int32(2), wsAttempts.Load(), "previous_response_not_found 应触发一次去掉 previous_response_id 的恢复重试")
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "resp_ws_prev_recover_ok", gjson.Get(rec.Body.String(), "id").String())
+
+ wsRequestMu.Lock()
+ requests := append([][]byte(nil), wsRequestPayloads...)
+ wsRequestMu.Unlock()
+ require.Len(t, requests, 2)
+ require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应保留 previous_response_id")
+ require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id")
+}
+
+func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryForFunctionCallOutput(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ var wsAttempts atomic.Int32
+ var wsRequestPayloads [][]byte
+ var wsRequestMu sync.Mutex
+ upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ wsAttempts.Add(1)
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ t.Errorf("upgrade websocket failed: %v", err)
+ return
+ }
+ defer func() {
+ _ = conn.Close()
+ }()
+
+ var req map[string]any
+ if err := conn.ReadJSON(&req); err != nil {
+ t.Errorf("read ws request failed: %v", err)
+ return
+ }
+ reqRaw, _ := json.Marshal(req)
+ wsRequestMu.Lock()
+ wsRequestPayloads = append(wsRequestPayloads, reqRaw)
+ wsRequestMu.Unlock()
+ _ = conn.WriteJSON(map[string]any{
+ "type": "error",
+ "error": map[string]any{
+ "code": "previous_response_not_found",
+ "type": "invalid_request_error",
+ "message": "previous response not found",
+ },
+ })
+ }))
+ defer wsServer.Close()
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "custom-client/1.0")
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)),
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: upstream,
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ }
+
+ account := &Account{
+ ID: 92,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": wsServer.URL,
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.Error(t, err)
+ require.Nil(t, result)
+ require.Nil(t, upstream.lastReq, "previous_response_not_found 不应回退 HTTP")
+ require.Equal(t, int32(1), wsAttempts.Load(), "function_call_output 场景应跳过 previous_response_not_found 自动恢复")
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.Contains(t, strings.ToLower(rec.Body.String()), "previous response not found")
+
+ wsRequestMu.Lock()
+ requests := append([][]byte(nil), wsRequestPayloads...)
+ wsRequestMu.Unlock()
+ require.Len(t, requests, 1)
+ require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists())
+}
+
+func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryWithoutPreviousResponseID(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ var wsAttempts atomic.Int32
+ var wsRequestPayloads [][]byte
+ var wsRequestMu sync.Mutex
+ upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ wsAttempts.Add(1)
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ t.Errorf("upgrade websocket failed: %v", err)
+ return
+ }
+ defer func() {
+ _ = conn.Close()
+ }()
+
+ var req map[string]any
+ if err := conn.ReadJSON(&req); err != nil {
+ t.Errorf("read ws request failed: %v", err)
+ return
+ }
+ reqRaw, _ := json.Marshal(req)
+ wsRequestMu.Lock()
+ wsRequestPayloads = append(wsRequestPayloads, reqRaw)
+ wsRequestMu.Unlock()
+ _ = conn.WriteJSON(map[string]any{
+ "type": "error",
+ "error": map[string]any{
+ "code": "previous_response_not_found",
+ "type": "invalid_request_error",
+ "message": "previous response not found",
+ },
+ })
+ }))
+ defer wsServer.Close()
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "custom-client/1.0")
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)),
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: upstream,
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ }
+
+ account := &Account{
+ ID: 93,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": wsServer.URL,
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.Error(t, err)
+ require.Nil(t, result)
+ require.Nil(t, upstream.lastReq, "WS 模式下 previous_response_not_found 不应回退 HTTP")
+ require.Equal(t, int32(1), wsAttempts.Load(), "缺少 previous_response_id 时应跳过自动恢复重试")
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+
+ wsRequestMu.Lock()
+ requests := append([][]byte(nil), wsRequestPayloads...)
+ wsRequestMu.Unlock()
+ require.Len(t, requests, 1)
+ require.False(t, gjson.GetBytes(requests[0], "previous_response_id").Exists())
+}
+
+func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundOnlyRecoversOnce(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ var wsAttempts atomic.Int32
+ var wsRequestPayloads [][]byte
+ var wsRequestMu sync.Mutex
+ upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ wsAttempts.Add(1)
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ t.Errorf("upgrade websocket failed: %v", err)
+ return
+ }
+ defer func() {
+ _ = conn.Close()
+ }()
+
+ var req map[string]any
+ if err := conn.ReadJSON(&req); err != nil {
+ t.Errorf("read ws request failed: %v", err)
+ return
+ }
+ reqRaw, _ := json.Marshal(req)
+ wsRequestMu.Lock()
+ wsRequestPayloads = append(wsRequestPayloads, reqRaw)
+ wsRequestMu.Unlock()
+ _ = conn.WriteJSON(map[string]any{
+ "type": "error",
+ "error": map[string]any{
+ "code": "previous_response_not_found",
+ "type": "invalid_request_error",
+ "message": "previous response not found",
+ },
+ })
+ }))
+ defer wsServer.Close()
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", "custom-client/1.0")
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)),
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: upstream,
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ }
+
+ account := &Account{
+ ID: 94,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": wsServer.URL,
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"input_text","text":"hello"}]}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.Error(t, err)
+ require.Nil(t, result)
+ require.Nil(t, upstream.lastReq, "WS 模式下 previous_response_not_found 不应回退 HTTP")
+ require.Equal(t, int32(2), wsAttempts.Load(), "应只允许一次自动恢复重试")
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+
+ wsRequestMu.Lock()
+ requests := append([][]byte(nil), wsRequestPayloads...)
+ wsRequestMu.Unlock()
+ require.Len(t, requests, 2)
+ require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应包含 previous_response_id")
+ require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id")
+}
diff --git a/backend/internal/service/openai_ws_protocol_resolver.go b/backend/internal/service/openai_ws_protocol_resolver.go
new file mode 100644
index 00000000..368643be
--- /dev/null
+++ b/backend/internal/service/openai_ws_protocol_resolver.go
@@ -0,0 +1,117 @@
+package service
+
+import "github.com/Wei-Shaw/sub2api/internal/config"
+
+// OpenAIUpstreamTransport 表示 OpenAI 上游传输协议。
+type OpenAIUpstreamTransport string
+
+const (
+ OpenAIUpstreamTransportAny OpenAIUpstreamTransport = ""
+ OpenAIUpstreamTransportHTTPSSE OpenAIUpstreamTransport = "http_sse"
+ OpenAIUpstreamTransportResponsesWebsocket OpenAIUpstreamTransport = "responses_websockets"
+ OpenAIUpstreamTransportResponsesWebsocketV2 OpenAIUpstreamTransport = "responses_websockets_v2"
+)
+
+// OpenAIWSProtocolDecision 表示协议决策结果。
+type OpenAIWSProtocolDecision struct {
+ Transport OpenAIUpstreamTransport
+ Reason string
+}
+
+// OpenAIWSProtocolResolver 定义 OpenAI 上游协议决策。
+type OpenAIWSProtocolResolver interface {
+ Resolve(account *Account) OpenAIWSProtocolDecision
+}
+
+type defaultOpenAIWSProtocolResolver struct {
+ cfg *config.Config
+}
+
+// NewOpenAIWSProtocolResolver 创建默认协议决策器。
+func NewOpenAIWSProtocolResolver(cfg *config.Config) OpenAIWSProtocolResolver {
+ return &defaultOpenAIWSProtocolResolver{cfg: cfg}
+}
+
+func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProtocolDecision {
+ if account == nil {
+ return openAIWSHTTPDecision("account_missing")
+ }
+ if !account.IsOpenAI() {
+ return openAIWSHTTPDecision("platform_not_openai")
+ }
+ if account.IsOpenAIWSForceHTTPEnabled() {
+ return openAIWSHTTPDecision("account_force_http")
+ }
+ if r == nil || r.cfg == nil {
+ return openAIWSHTTPDecision("config_missing")
+ }
+
+ wsCfg := r.cfg.Gateway.OpenAIWS
+ if wsCfg.ForceHTTP {
+ return openAIWSHTTPDecision("global_force_http")
+ }
+ if !wsCfg.Enabled {
+ return openAIWSHTTPDecision("global_disabled")
+ }
+ if account.IsOpenAIOAuth() {
+ if !wsCfg.OAuthEnabled {
+ return openAIWSHTTPDecision("oauth_disabled")
+ }
+ } else if account.IsOpenAIApiKey() {
+ if !wsCfg.APIKeyEnabled {
+ return openAIWSHTTPDecision("apikey_disabled")
+ }
+ } else {
+ return openAIWSHTTPDecision("unknown_auth_type")
+ }
+ if wsCfg.ModeRouterV2Enabled {
+ mode := account.ResolveOpenAIResponsesWebSocketV2Mode(wsCfg.IngressModeDefault)
+ switch mode {
+ case OpenAIWSIngressModeOff:
+ return openAIWSHTTPDecision("account_mode_off")
+ case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
+ // continue
+ default:
+ return openAIWSHTTPDecision("account_mode_off")
+ }
+ if account.Concurrency <= 0 {
+ return openAIWSHTTPDecision("account_concurrency_invalid")
+ }
+ if wsCfg.ResponsesWebsocketsV2 {
+ return OpenAIWSProtocolDecision{
+ Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
+ Reason: "ws_v2_mode_" + mode,
+ }
+ }
+ if wsCfg.ResponsesWebsockets {
+ return OpenAIWSProtocolDecision{
+ Transport: OpenAIUpstreamTransportResponsesWebsocket,
+ Reason: "ws_v1_mode_" + mode,
+ }
+ }
+ return openAIWSHTTPDecision("feature_disabled")
+ }
+ if !account.IsOpenAIResponsesWebSocketV2Enabled() {
+ return openAIWSHTTPDecision("account_disabled")
+ }
+ if wsCfg.ResponsesWebsocketsV2 {
+ return OpenAIWSProtocolDecision{
+ Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
+ Reason: "ws_v2_enabled",
+ }
+ }
+ if wsCfg.ResponsesWebsockets {
+ return OpenAIWSProtocolDecision{
+ Transport: OpenAIUpstreamTransportResponsesWebsocket,
+ Reason: "ws_v1_enabled",
+ }
+ }
+ return openAIWSHTTPDecision("feature_disabled")
+}
+
+func openAIWSHTTPDecision(reason string) OpenAIWSProtocolDecision {
+ return OpenAIWSProtocolDecision{
+ Transport: OpenAIUpstreamTransportHTTPSSE,
+ Reason: reason,
+ }
+}
diff --git a/backend/internal/service/openai_ws_protocol_resolver_test.go b/backend/internal/service/openai_ws_protocol_resolver_test.go
new file mode 100644
index 00000000..5be76e28
--- /dev/null
+++ b/backend/internal/service/openai_ws_protocol_resolver_test.go
@@ -0,0 +1,203 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+func TestOpenAIWSProtocolResolver_Resolve(t *testing.T) {
+ baseCfg := &config.Config{}
+ baseCfg.Gateway.OpenAIWS.Enabled = true
+ baseCfg.Gateway.OpenAIWS.OAuthEnabled = true
+ baseCfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ baseCfg.Gateway.OpenAIWS.ResponsesWebsockets = false
+ baseCfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+
+ openAIOAuthEnabled := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "openai_oauth_responses_websockets_v2_enabled": true,
+ },
+ }
+
+ t.Run("v2优先", func(t *testing.T) {
+ decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(openAIOAuthEnabled)
+ require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
+ require.Equal(t, "ws_v2_enabled", decision.Reason)
+ })
+
+ t.Run("v2关闭时回退v1", func(t *testing.T) {
+ cfg := *baseCfg
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false
+ cfg.Gateway.OpenAIWS.ResponsesWebsockets = true
+
+ decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled)
+ require.Equal(t, OpenAIUpstreamTransportResponsesWebsocket, decision.Transport)
+ require.Equal(t, "ws_v1_enabled", decision.Reason)
+ })
+
+ t.Run("透传开关不影响WS协议判定", func(t *testing.T) {
+ account := *openAIOAuthEnabled
+ account.Extra = map[string]any{
+ "openai_oauth_responses_websockets_v2_enabled": true,
+ "openai_passthrough": true,
+ }
+ decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
+ require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
+ require.Equal(t, "ws_v2_enabled", decision.Reason)
+ })
+
+ t.Run("账号级强制HTTP", func(t *testing.T) {
+ account := *openAIOAuthEnabled
+ account.Extra = map[string]any{
+ "openai_oauth_responses_websockets_v2_enabled": true,
+ "openai_ws_force_http": true,
+ }
+ decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
+ require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
+ require.Equal(t, "account_force_http", decision.Reason)
+ })
+
+ t.Run("全局关闭保持HTTP", func(t *testing.T) {
+ cfg := *baseCfg
+ cfg.Gateway.OpenAIWS.Enabled = false
+ decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled)
+ require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
+ require.Equal(t, "global_disabled", decision.Reason)
+ })
+
+ t.Run("账号开关关闭保持HTTP", func(t *testing.T) {
+ account := *openAIOAuthEnabled
+ account.Extra = map[string]any{
+ "openai_oauth_responses_websockets_v2_enabled": false,
+ }
+ decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
+ require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
+ require.Equal(t, "account_disabled", decision.Reason)
+ })
+
+ t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) {
+ account := *openAIOAuthEnabled
+ account.Extra = map[string]any{
+ "openai_apikey_responses_websockets_v2_enabled": true,
+ }
+ decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
+ require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
+ require.Equal(t, "account_disabled", decision.Reason)
+ })
+
+ t.Run("兼容旧键openai_ws_enabled", func(t *testing.T) {
+ account := *openAIOAuthEnabled
+ account.Extra = map[string]any{
+ "openai_ws_enabled": true,
+ }
+ decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account)
+ require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
+ require.Equal(t, "ws_v2_enabled", decision.Reason)
+ })
+
+ t.Run("按账号类型开关控制", func(t *testing.T) {
+ cfg := *baseCfg
+ cfg.Gateway.OpenAIWS.OAuthEnabled = false
+ decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled)
+ require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
+ require.Equal(t, "oauth_disabled", decision.Reason)
+ })
+
+ t.Run("API Key 账号关闭开关时回退HTTP", func(t *testing.T) {
+ cfg := *baseCfg
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = false
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_enabled": true,
+ },
+ }
+ decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(account)
+ require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
+ require.Equal(t, "apikey_disabled", decision.Reason)
+ })
+
+ t.Run("未知认证类型回退HTTP", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: "unknown_type",
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+ decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(account)
+ require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
+ require.Equal(t, "unknown_auth_type", decision.Reason)
+ })
+}
+
+func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
+ cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared
+
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Extra: map[string]any{
+ "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
+ },
+ }
+
+ t.Run("dedicated mode routes to ws v2", func(t *testing.T) {
+ decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account)
+ require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
+ require.Equal(t, "ws_v2_mode_dedicated", decision.Reason)
+ })
+
+ t.Run("off mode routes to http", func(t *testing.T) {
+ offAccount := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Extra: map[string]any{
+ "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff,
+ },
+ }
+ decision := NewOpenAIWSProtocolResolver(cfg).Resolve(offAccount)
+ require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
+ require.Equal(t, "account_mode_off", decision.Reason)
+ })
+
+ t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) {
+ legacyAccount := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_enabled": true,
+ },
+ }
+ decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount)
+ require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
+ require.Equal(t, "ws_v2_mode_shared", decision.Reason)
+ })
+
+ t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) {
+ invalidConcurrency := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
+ },
+ }
+ decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency)
+ require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport)
+ require.Equal(t, "account_concurrency_invalid", decision.Reason)
+ })
+}
diff --git a/backend/internal/service/openai_ws_state_store.go b/backend/internal/service/openai_ws_state_store.go
new file mode 100644
index 00000000..b606baa1
--- /dev/null
+++ b/backend/internal/service/openai_ws_state_store.go
@@ -0,0 +1,440 @@
+package service
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/hex"
+ "fmt"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+const (
+ openAIWSResponseAccountCachePrefix = "openai:response:"
+ openAIWSStateStoreCleanupInterval = time.Minute
+ openAIWSStateStoreCleanupMaxPerMap = 512
+ openAIWSStateStoreMaxEntriesPerMap = 65536
+ openAIWSStateStoreRedisTimeout = 3 * time.Second
+)
+
+type openAIWSAccountBinding struct {
+ accountID int64
+ expiresAt time.Time
+}
+
+type openAIWSConnBinding struct {
+ connID string
+ expiresAt time.Time
+}
+
+type openAIWSTurnStateBinding struct {
+ turnState string
+ expiresAt time.Time
+}
+
+type openAIWSSessionConnBinding struct {
+ connID string
+ expiresAt time.Time
+}
+
+// OpenAIWSStateStore 管理 WSv2 的粘连状态。
+// - response_id -> account_id 用于续链路由
+// - response_id -> conn_id 用于连接内上下文复用
+//
+// response_id -> account_id 优先走 GatewayCache(Redis),同时维护本地热缓存。
+// response_id -> conn_id 仅在本进程内有效。
+type OpenAIWSStateStore interface {
+ BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error
+ GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error)
+ DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error
+
+ BindResponseConn(responseID, connID string, ttl time.Duration)
+ GetResponseConn(responseID string) (string, bool)
+ DeleteResponseConn(responseID string)
+
+ BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration)
+ GetSessionTurnState(groupID int64, sessionHash string) (string, bool)
+ DeleteSessionTurnState(groupID int64, sessionHash string)
+
+ BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration)
+ GetSessionConn(groupID int64, sessionHash string) (string, bool)
+ DeleteSessionConn(groupID int64, sessionHash string)
+}
+
+type defaultOpenAIWSStateStore struct {
+ cache GatewayCache
+
+ responseToAccountMu sync.RWMutex
+ responseToAccount map[string]openAIWSAccountBinding
+ responseToConnMu sync.RWMutex
+ responseToConn map[string]openAIWSConnBinding
+ sessionToTurnStateMu sync.RWMutex
+ sessionToTurnState map[string]openAIWSTurnStateBinding
+ sessionToConnMu sync.RWMutex
+ sessionToConn map[string]openAIWSSessionConnBinding
+
+ lastCleanupUnixNano atomic.Int64
+}
+
+// NewOpenAIWSStateStore 创建默认 WS 状态存储。
+func NewOpenAIWSStateStore(cache GatewayCache) OpenAIWSStateStore {
+ store := &defaultOpenAIWSStateStore{
+ cache: cache,
+ responseToAccount: make(map[string]openAIWSAccountBinding, 256),
+ responseToConn: make(map[string]openAIWSConnBinding, 256),
+ sessionToTurnState: make(map[string]openAIWSTurnStateBinding, 256),
+ sessionToConn: make(map[string]openAIWSSessionConnBinding, 256),
+ }
+ store.lastCleanupUnixNano.Store(time.Now().UnixNano())
+ return store
+}
+
+func (s *defaultOpenAIWSStateStore) BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error {
+ id := normalizeOpenAIWSResponseID(responseID)
+ if id == "" || accountID <= 0 {
+ return nil
+ }
+ ttl = normalizeOpenAIWSTTL(ttl)
+ s.maybeCleanup()
+
+ expiresAt := time.Now().Add(ttl)
+ s.responseToAccountMu.Lock()
+ ensureBindingCapacity(s.responseToAccount, id, openAIWSStateStoreMaxEntriesPerMap)
+ s.responseToAccount[id] = openAIWSAccountBinding{accountID: accountID, expiresAt: expiresAt}
+ s.responseToAccountMu.Unlock()
+
+ if s.cache == nil {
+ return nil
+ }
+ cacheKey := openAIWSResponseAccountCacheKey(id)
+ cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
+ defer cancel()
+ return s.cache.SetSessionAccountID(cacheCtx, groupID, cacheKey, accountID, ttl)
+}
+
+func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error) {
+ id := normalizeOpenAIWSResponseID(responseID)
+ if id == "" {
+ return 0, nil
+ }
+ s.maybeCleanup()
+
+ now := time.Now()
+ s.responseToAccountMu.RLock()
+ if binding, ok := s.responseToAccount[id]; ok {
+ if now.Before(binding.expiresAt) {
+ accountID := binding.accountID
+ s.responseToAccountMu.RUnlock()
+ return accountID, nil
+ }
+ }
+ s.responseToAccountMu.RUnlock()
+
+ if s.cache == nil {
+ return 0, nil
+ }
+
+ cacheKey := openAIWSResponseAccountCacheKey(id)
+ cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
+ defer cancel()
+ accountID, err := s.cache.GetSessionAccountID(cacheCtx, groupID, cacheKey)
+ if err != nil || accountID <= 0 {
+ // 缓存读取失败不阻断主流程,按未命中降级。
+ return 0, nil
+ }
+ return accountID, nil
+}
+
+func (s *defaultOpenAIWSStateStore) DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error {
+ id := normalizeOpenAIWSResponseID(responseID)
+ if id == "" {
+ return nil
+ }
+ s.responseToAccountMu.Lock()
+ delete(s.responseToAccount, id)
+ s.responseToAccountMu.Unlock()
+
+ if s.cache == nil {
+ return nil
+ }
+ cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
+ defer cancel()
+ return s.cache.DeleteSessionAccountID(cacheCtx, groupID, openAIWSResponseAccountCacheKey(id))
+}
+
+func (s *defaultOpenAIWSStateStore) BindResponseConn(responseID, connID string, ttl time.Duration) {
+ id := normalizeOpenAIWSResponseID(responseID)
+ conn := strings.TrimSpace(connID)
+ if id == "" || conn == "" {
+ return
+ }
+ ttl = normalizeOpenAIWSTTL(ttl)
+ s.maybeCleanup()
+
+ s.responseToConnMu.Lock()
+ ensureBindingCapacity(s.responseToConn, id, openAIWSStateStoreMaxEntriesPerMap)
+ s.responseToConn[id] = openAIWSConnBinding{
+ connID: conn,
+ expiresAt: time.Now().Add(ttl),
+ }
+ s.responseToConnMu.Unlock()
+}
+
+func (s *defaultOpenAIWSStateStore) GetResponseConn(responseID string) (string, bool) {
+ id := normalizeOpenAIWSResponseID(responseID)
+ if id == "" {
+ return "", false
+ }
+ s.maybeCleanup()
+
+ now := time.Now()
+ s.responseToConnMu.RLock()
+ binding, ok := s.responseToConn[id]
+ s.responseToConnMu.RUnlock()
+ if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" {
+ return "", false
+ }
+ return binding.connID, true
+}
+
+func (s *defaultOpenAIWSStateStore) DeleteResponseConn(responseID string) {
+ id := normalizeOpenAIWSResponseID(responseID)
+ if id == "" {
+ return
+ }
+ s.responseToConnMu.Lock()
+ delete(s.responseToConn, id)
+ s.responseToConnMu.Unlock()
+}
+
+func (s *defaultOpenAIWSStateStore) BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) {
+ key := openAIWSSessionTurnStateKey(groupID, sessionHash)
+ state := strings.TrimSpace(turnState)
+ if key == "" || state == "" {
+ return
+ }
+ ttl = normalizeOpenAIWSTTL(ttl)
+ s.maybeCleanup()
+
+ s.sessionToTurnStateMu.Lock()
+ ensureBindingCapacity(s.sessionToTurnState, key, openAIWSStateStoreMaxEntriesPerMap)
+ s.sessionToTurnState[key] = openAIWSTurnStateBinding{
+ turnState: state,
+ expiresAt: time.Now().Add(ttl),
+ }
+ s.sessionToTurnStateMu.Unlock()
+}
+
+func (s *defaultOpenAIWSStateStore) GetSessionTurnState(groupID int64, sessionHash string) (string, bool) {
+ key := openAIWSSessionTurnStateKey(groupID, sessionHash)
+ if key == "" {
+ return "", false
+ }
+ s.maybeCleanup()
+
+ now := time.Now()
+ s.sessionToTurnStateMu.RLock()
+ binding, ok := s.sessionToTurnState[key]
+ s.sessionToTurnStateMu.RUnlock()
+ if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.turnState) == "" {
+ return "", false
+ }
+ return binding.turnState, true
+}
+
+func (s *defaultOpenAIWSStateStore) DeleteSessionTurnState(groupID int64, sessionHash string) {
+ key := openAIWSSessionTurnStateKey(groupID, sessionHash)
+ if key == "" {
+ return
+ }
+ s.sessionToTurnStateMu.Lock()
+ delete(s.sessionToTurnState, key)
+ s.sessionToTurnStateMu.Unlock()
+}
+
+func (s *defaultOpenAIWSStateStore) BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) {
+ key := openAIWSSessionTurnStateKey(groupID, sessionHash)
+ conn := strings.TrimSpace(connID)
+ if key == "" || conn == "" {
+ return
+ }
+ ttl = normalizeOpenAIWSTTL(ttl)
+ s.maybeCleanup()
+
+ s.sessionToConnMu.Lock()
+ ensureBindingCapacity(s.sessionToConn, key, openAIWSStateStoreMaxEntriesPerMap)
+ s.sessionToConn[key] = openAIWSSessionConnBinding{
+ connID: conn,
+ expiresAt: time.Now().Add(ttl),
+ }
+ s.sessionToConnMu.Unlock()
+}
+
+func (s *defaultOpenAIWSStateStore) GetSessionConn(groupID int64, sessionHash string) (string, bool) {
+ key := openAIWSSessionTurnStateKey(groupID, sessionHash)
+ if key == "" {
+ return "", false
+ }
+ s.maybeCleanup()
+
+ now := time.Now()
+ s.sessionToConnMu.RLock()
+ binding, ok := s.sessionToConn[key]
+ s.sessionToConnMu.RUnlock()
+ if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" {
+ return "", false
+ }
+ return binding.connID, true
+}
+
+func (s *defaultOpenAIWSStateStore) DeleteSessionConn(groupID int64, sessionHash string) {
+ key := openAIWSSessionTurnStateKey(groupID, sessionHash)
+ if key == "" {
+ return
+ }
+ s.sessionToConnMu.Lock()
+ delete(s.sessionToConn, key)
+ s.sessionToConnMu.Unlock()
+}
+
+func (s *defaultOpenAIWSStateStore) maybeCleanup() {
+ if s == nil {
+ return
+ }
+ now := time.Now()
+ last := time.Unix(0, s.lastCleanupUnixNano.Load())
+ if now.Sub(last) < openAIWSStateStoreCleanupInterval {
+ return
+ }
+ if !s.lastCleanupUnixNano.CompareAndSwap(last.UnixNano(), now.UnixNano()) {
+ return
+ }
+
+ // 增量限额清理,避免高规模下一次性全量扫描导致长时间阻塞。
+ s.responseToAccountMu.Lock()
+ cleanupExpiredAccountBindings(s.responseToAccount, now, openAIWSStateStoreCleanupMaxPerMap)
+ s.responseToAccountMu.Unlock()
+
+ s.responseToConnMu.Lock()
+ cleanupExpiredConnBindings(s.responseToConn, now, openAIWSStateStoreCleanupMaxPerMap)
+ s.responseToConnMu.Unlock()
+
+ s.sessionToTurnStateMu.Lock()
+ cleanupExpiredTurnStateBindings(s.sessionToTurnState, now, openAIWSStateStoreCleanupMaxPerMap)
+ s.sessionToTurnStateMu.Unlock()
+
+ s.sessionToConnMu.Lock()
+ cleanupExpiredSessionConnBindings(s.sessionToConn, now, openAIWSStateStoreCleanupMaxPerMap)
+ s.sessionToConnMu.Unlock()
+}
+
+func cleanupExpiredAccountBindings(bindings map[string]openAIWSAccountBinding, now time.Time, maxScan int) {
+ if len(bindings) == 0 || maxScan <= 0 {
+ return
+ }
+ scanned := 0
+ for key, binding := range bindings {
+ if now.After(binding.expiresAt) {
+ delete(bindings, key)
+ }
+ scanned++
+ if scanned >= maxScan {
+ break
+ }
+ }
+}
+
+func cleanupExpiredConnBindings(bindings map[string]openAIWSConnBinding, now time.Time, maxScan int) {
+ if len(bindings) == 0 || maxScan <= 0 {
+ return
+ }
+ scanned := 0
+ for key, binding := range bindings {
+ if now.After(binding.expiresAt) {
+ delete(bindings, key)
+ }
+ scanned++
+ if scanned >= maxScan {
+ break
+ }
+ }
+}
+
+func cleanupExpiredTurnStateBindings(bindings map[string]openAIWSTurnStateBinding, now time.Time, maxScan int) {
+ if len(bindings) == 0 || maxScan <= 0 {
+ return
+ }
+ scanned := 0
+ for key, binding := range bindings {
+ if now.After(binding.expiresAt) {
+ delete(bindings, key)
+ }
+ scanned++
+ if scanned >= maxScan {
+ break
+ }
+ }
+}
+
+func cleanupExpiredSessionConnBindings(bindings map[string]openAIWSSessionConnBinding, now time.Time, maxScan int) {
+ if len(bindings) == 0 || maxScan <= 0 {
+ return
+ }
+ scanned := 0
+ for key, binding := range bindings {
+ if now.After(binding.expiresAt) {
+ delete(bindings, key)
+ }
+ scanned++
+ if scanned >= maxScan {
+ break
+ }
+ }
+}
+
+func ensureBindingCapacity[T any](bindings map[string]T, incomingKey string, maxEntries int) {
+ if len(bindings) < maxEntries || maxEntries <= 0 {
+ return
+ }
+ if _, exists := bindings[incomingKey]; exists {
+ return
+ }
+ // 固定上限保护:淘汰任意一项,优先保证内存有界。
+ for key := range bindings {
+ delete(bindings, key)
+ return
+ }
+}
+
+func normalizeOpenAIWSResponseID(responseID string) string {
+ return strings.TrimSpace(responseID)
+}
+
+func openAIWSResponseAccountCacheKey(responseID string) string {
+ sum := sha256.Sum256([]byte(responseID))
+ return openAIWSResponseAccountCachePrefix + hex.EncodeToString(sum[:])
+}
+
+func normalizeOpenAIWSTTL(ttl time.Duration) time.Duration {
+ if ttl <= 0 {
+ return time.Hour
+ }
+ return ttl
+}
+
+func openAIWSSessionTurnStateKey(groupID int64, sessionHash string) string {
+ hash := strings.TrimSpace(sessionHash)
+ if hash == "" {
+ return ""
+ }
+ return fmt.Sprintf("%d:%s", groupID, hash)
+}
+
+func withOpenAIWSStateStoreRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ return context.WithTimeout(ctx, openAIWSStateStoreRedisTimeout)
+}
diff --git a/backend/internal/service/openai_ws_state_store_test.go b/backend/internal/service/openai_ws_state_store_test.go
new file mode 100644
index 00000000..235d4233
--- /dev/null
+++ b/backend/internal/service/openai_ws_state_store_test.go
@@ -0,0 +1,235 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestOpenAIWSStateStore_BindGetDeleteResponseAccount(t *testing.T) {
+ cache := &stubGatewayCache{}
+ store := NewOpenAIWSStateStore(cache)
+ ctx := context.Background()
+ groupID := int64(7)
+
+ require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_abc", 101, time.Minute))
+
+ accountID, err := store.GetResponseAccount(ctx, groupID, "resp_abc")
+ require.NoError(t, err)
+ require.Equal(t, int64(101), accountID)
+
+ require.NoError(t, store.DeleteResponseAccount(ctx, groupID, "resp_abc"))
+ accountID, err = store.GetResponseAccount(ctx, groupID, "resp_abc")
+ require.NoError(t, err)
+ require.Zero(t, accountID)
+}
+
+func TestOpenAIWSStateStore_ResponseConnTTL(t *testing.T) {
+ store := NewOpenAIWSStateStore(nil)
+ store.BindResponseConn("resp_conn", "conn_1", 30*time.Millisecond)
+
+ connID, ok := store.GetResponseConn("resp_conn")
+ require.True(t, ok)
+ require.Equal(t, "conn_1", connID)
+
+ time.Sleep(60 * time.Millisecond)
+ _, ok = store.GetResponseConn("resp_conn")
+ require.False(t, ok)
+}
+
+func TestOpenAIWSStateStore_SessionTurnStateTTL(t *testing.T) {
+ store := NewOpenAIWSStateStore(nil)
+ store.BindSessionTurnState(9, "session_hash_1", "turn_state_1", 30*time.Millisecond)
+
+ state, ok := store.GetSessionTurnState(9, "session_hash_1")
+ require.True(t, ok)
+ require.Equal(t, "turn_state_1", state)
+
+ // group 隔离
+ _, ok = store.GetSessionTurnState(10, "session_hash_1")
+ require.False(t, ok)
+
+ time.Sleep(60 * time.Millisecond)
+ _, ok = store.GetSessionTurnState(9, "session_hash_1")
+ require.False(t, ok)
+}
+
+func TestOpenAIWSStateStore_SessionConnTTL(t *testing.T) {
+ store := NewOpenAIWSStateStore(nil)
+ store.BindSessionConn(9, "session_hash_conn_1", "conn_1", 30*time.Millisecond)
+
+ connID, ok := store.GetSessionConn(9, "session_hash_conn_1")
+ require.True(t, ok)
+ require.Equal(t, "conn_1", connID)
+
+ // group 隔离
+ _, ok = store.GetSessionConn(10, "session_hash_conn_1")
+ require.False(t, ok)
+
+ time.Sleep(60 * time.Millisecond)
+ _, ok = store.GetSessionConn(9, "session_hash_conn_1")
+ require.False(t, ok)
+}
+
+func TestOpenAIWSStateStore_GetResponseAccount_NoStaleAfterCacheMiss(t *testing.T) {
+ cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
+ store := NewOpenAIWSStateStore(cache)
+ ctx := context.Background()
+ groupID := int64(17)
+ responseID := "resp_cache_stale"
+ cacheKey := openAIWSResponseAccountCacheKey(responseID)
+
+ cache.sessionBindings[cacheKey] = 501
+ accountID, err := store.GetResponseAccount(ctx, groupID, responseID)
+ require.NoError(t, err)
+ require.Equal(t, int64(501), accountID)
+
+ delete(cache.sessionBindings, cacheKey)
+ accountID, err = store.GetResponseAccount(ctx, groupID, responseID)
+ require.NoError(t, err)
+ require.Zero(t, accountID, "上游缓存失效后不应继续命中本地陈旧映射")
+}
+
+func TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally(t *testing.T) {
+ raw := NewOpenAIWSStateStore(nil)
+ store, ok := raw.(*defaultOpenAIWSStateStore)
+ require.True(t, ok)
+
+ expiredAt := time.Now().Add(-time.Minute)
+ total := 2048
+ store.responseToConnMu.Lock()
+ for i := 0; i < total; i++ {
+ store.responseToConn[fmt.Sprintf("resp_%d", i)] = openAIWSConnBinding{
+ connID: "conn_incremental",
+ expiresAt: expiredAt,
+ }
+ }
+ store.responseToConnMu.Unlock()
+
+ store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano())
+ store.maybeCleanup()
+
+ store.responseToConnMu.RLock()
+ remainingAfterFirst := len(store.responseToConn)
+ store.responseToConnMu.RUnlock()
+ require.Less(t, remainingAfterFirst, total, "单轮 cleanup 应至少有进展")
+ require.Greater(t, remainingAfterFirst, 0, "增量清理不要求单轮清空全部键")
+
+ for i := 0; i < 8; i++ {
+ store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano())
+ store.maybeCleanup()
+ }
+
+ store.responseToConnMu.RLock()
+ remaining := len(store.responseToConn)
+ store.responseToConnMu.RUnlock()
+ require.Zero(t, remaining, "多轮 cleanup 后应逐步清空全部过期键")
+}
+
+func TestEnsureBindingCapacity_EvictsOneWhenMapIsFull(t *testing.T) {
+ bindings := map[string]int{
+ "a": 1,
+ "b": 2,
+ }
+
+ ensureBindingCapacity(bindings, "c", 2)
+ bindings["c"] = 3
+
+ require.Len(t, bindings, 2)
+ require.Equal(t, 3, bindings["c"])
+}
+
+func TestEnsureBindingCapacity_DoesNotEvictWhenUpdatingExistingKey(t *testing.T) {
+ bindings := map[string]int{
+ "a": 1,
+ "b": 2,
+ }
+
+ ensureBindingCapacity(bindings, "a", 2)
+ bindings["a"] = 9
+
+ require.Len(t, bindings, 2)
+ require.Equal(t, 9, bindings["a"])
+}
+
+type openAIWSStateStoreTimeoutProbeCache struct {
+ setHasDeadline bool
+ getHasDeadline bool
+ deleteHasDeadline bool
+ setDeadlineDelta time.Duration
+ getDeadlineDelta time.Duration
+ delDeadlineDelta time.Duration
+}
+
+func (c *openAIWSStateStoreTimeoutProbeCache) GetSessionAccountID(ctx context.Context, _ int64, _ string) (int64, error) {
+ if deadline, ok := ctx.Deadline(); ok {
+ c.getHasDeadline = true
+ c.getDeadlineDelta = time.Until(deadline)
+ }
+ return 123, nil
+}
+
+func (c *openAIWSStateStoreTimeoutProbeCache) SetSessionAccountID(ctx context.Context, _ int64, _ string, _ int64, _ time.Duration) error {
+ if deadline, ok := ctx.Deadline(); ok {
+ c.setHasDeadline = true
+ c.setDeadlineDelta = time.Until(deadline)
+ }
+ return errors.New("set failed")
+}
+
+func (c *openAIWSStateStoreTimeoutProbeCache) RefreshSessionTTL(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (c *openAIWSStateStoreTimeoutProbeCache) DeleteSessionAccountID(ctx context.Context, _ int64, _ string) error {
+ if deadline, ok := ctx.Deadline(); ok {
+ c.deleteHasDeadline = true
+ c.delDeadlineDelta = time.Until(deadline)
+ }
+ return nil
+}
+
+func TestOpenAIWSStateStore_RedisOpsUseShortTimeout(t *testing.T) {
+ probe := &openAIWSStateStoreTimeoutProbeCache{}
+ store := NewOpenAIWSStateStore(probe)
+ ctx := context.Background()
+ groupID := int64(5)
+
+ err := store.BindResponseAccount(ctx, groupID, "resp_timeout_probe", 11, time.Minute)
+ require.Error(t, err)
+
+ accountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_timeout_probe")
+ require.NoError(t, getErr)
+ require.Equal(t, int64(11), accountID, "本地缓存命中应优先返回已绑定账号")
+
+ require.NoError(t, store.DeleteResponseAccount(ctx, groupID, "resp_timeout_probe"))
+
+ require.True(t, probe.setHasDeadline, "SetSessionAccountID 应携带独立超时上下文")
+ require.True(t, probe.deleteHasDeadline, "DeleteSessionAccountID 应携带独立超时上下文")
+ require.False(t, probe.getHasDeadline, "GetSessionAccountID 本用例应由本地缓存命中,不触发 Redis 读取")
+ require.Greater(t, probe.setDeadlineDelta, 2*time.Second)
+ require.LessOrEqual(t, probe.setDeadlineDelta, 3*time.Second)
+ require.Greater(t, probe.delDeadlineDelta, 2*time.Second)
+ require.LessOrEqual(t, probe.delDeadlineDelta, 3*time.Second)
+
+ probe2 := &openAIWSStateStoreTimeoutProbeCache{}
+ store2 := NewOpenAIWSStateStore(probe2)
+ accountID2, err2 := store2.GetResponseAccount(ctx, groupID, "resp_cache_only")
+ require.NoError(t, err2)
+ require.Equal(t, int64(123), accountID2)
+ require.True(t, probe2.getHasDeadline, "GetSessionAccountID 在缓存未命中时应携带独立超时上下文")
+ require.Greater(t, probe2.getDeadlineDelta, 2*time.Second)
+ require.LessOrEqual(t, probe2.getDeadlineDelta, 3*time.Second)
+}
+
+func TestWithOpenAIWSStateStoreRedisTimeout_WithParentContext(t *testing.T) {
+ ctx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background())
+ defer cancel()
+ require.NotNil(t, ctx)
+ _, ok := ctx.Deadline()
+ require.True(t, ok, "应附加短超时")
+}
diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go
index 23a524ad..f0daa3e2 100644
--- a/backend/internal/service/ops_retry.go
+++ b/backend/internal/service/ops_retry.go
@@ -13,7 +13,6 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
- "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/gin-gonic/gin"
"github.com/lib/pq"
@@ -480,7 +479,7 @@ func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryReq
attemptCtx := ctx
if switches > 0 {
- attemptCtx = context.WithValue(attemptCtx, ctxkey.AccountSwitchCount, switches)
+ attemptCtx = WithAccountSwitchCount(attemptCtx, switches, false)
}
exec := func() *opsRetryExecution {
defer selection.ReleaseFunc()
@@ -675,6 +674,7 @@ func newOpsRetryContext(ctx context.Context, errorLog *OpsErrorLogDetail) (*gin.
}
c.Request = req
+ SetOpenAIClientTransport(c, OpenAIClientTransportHTTP)
return c, w
}
diff --git a/backend/internal/service/ops_retry_context_test.go b/backend/internal/service/ops_retry_context_test.go
new file mode 100644
index 00000000..a8c26ee4
--- /dev/null
+++ b/backend/internal/service/ops_retry_context_test.go
@@ -0,0 +1,47 @@
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewOpsRetryContext_SetsHTTPTransportAndRequestHeaders(t *testing.T) {
+ errorLog := &OpsErrorLogDetail{
+ OpsErrorLog: OpsErrorLog{
+ RequestPath: "/openai/v1/responses",
+ },
+ UserAgent: "ops-retry-agent/1.0",
+ RequestHeaders: `{
+ "anthropic-beta":"beta-v1",
+ "ANTHROPIC-VERSION":"2023-06-01",
+ "authorization":"Bearer should-not-forward"
+ }`,
+ }
+
+ c, w := newOpsRetryContext(context.Background(), errorLog)
+ require.NotNil(t, c)
+ require.NotNil(t, w)
+ require.NotNil(t, c.Request)
+
+ require.Equal(t, "/openai/v1/responses", c.Request.URL.Path)
+ require.Equal(t, "application/json", c.Request.Header.Get("Content-Type"))
+ require.Equal(t, "ops-retry-agent/1.0", c.Request.Header.Get("User-Agent"))
+ require.Equal(t, "beta-v1", c.Request.Header.Get("anthropic-beta"))
+ require.Equal(t, "2023-06-01", c.Request.Header.Get("anthropic-version"))
+ require.Empty(t, c.Request.Header.Get("authorization"), "未在白名单内的敏感头不应被重放")
+ require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c))
+}
+
+func TestNewOpsRetryContext_InvalidHeadersJSONStillSetsHTTPTransport(t *testing.T) {
+ errorLog := &OpsErrorLogDetail{
+ RequestHeaders: "{invalid-json",
+ }
+
+ c, _ := newOpsRetryContext(context.Background(), errorLog)
+ require.NotNil(t, c)
+ require.NotNil(t, c.Request)
+ require.Equal(t, "/", c.Request.URL.Path)
+ require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c))
+}
diff --git a/backend/internal/service/ops_upstream_context.go b/backend/internal/service/ops_upstream_context.go
index 23c154ce..21e09c43 100644
--- a/backend/internal/service/ops_upstream_context.go
+++ b/backend/internal/service/ops_upstream_context.go
@@ -27,6 +27,11 @@ const (
OpsUpstreamLatencyMsKey = "ops_upstream_latency_ms"
OpsResponseLatencyMsKey = "ops_response_latency_ms"
OpsTimeToFirstTokenMsKey = "ops_time_to_first_token_ms"
+ // OpenAI WS 关键观测字段
+ OpsOpenAIWSQueueWaitMsKey = "ops_openai_ws_queue_wait_ms"
+ OpsOpenAIWSConnPickMsKey = "ops_openai_ws_conn_pick_ms"
+ OpsOpenAIWSConnReusedKey = "ops_openai_ws_conn_reused"
+ OpsOpenAIWSConnIDKey = "ops_openai_ws_conn_id"
// OpsSkipPassthroughKey 由 applyErrorPassthroughRule 在命中 skip_monitoring=true 的规则时设置。
// ops_error_logger 中间件检查此 key,为 true 时跳过错误记录。
diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go
index fcc7c4a0..d4d70536 100644
--- a/backend/internal/service/ratelimit_service.go
+++ b/backend/internal/service/ratelimit_service.go
@@ -11,6 +11,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// RateLimitService 处理限流和过载状态管理
@@ -33,6 +34,10 @@ type geminiUsageCacheEntry struct {
totals GeminiUsageTotals
}
+type geminiUsageTotalsBatchProvider interface {
+ GetGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, startTime, endTime time.Time) (map[int64]GeminiUsageTotals, error)
+}
+
const geminiPrecheckCacheTTL = time.Minute
// NewRateLimitService 创建RateLimitService实例
@@ -162,6 +167,17 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
if upstreamMsg != "" {
msg = "Access forbidden (403): " + upstreamMsg
}
+ logger.LegacyPrintf(
+ "service.ratelimit",
+ "[HandleUpstreamErrorRaw] account_id=%d platform=%s type=%s status=403 request_id=%s cf_ray=%s upstream_msg=%s raw_body=%s",
+ account.ID,
+ account.Platform,
+ account.Type,
+ strings.TrimSpace(headers.Get("x-request-id")),
+ strings.TrimSpace(headers.Get("cf-ray")),
+ upstreamMsg,
+ truncateForLog(responseBody, 1024),
+ )
s.handleAuthError(ctx, account, msg)
shouldDisable = true
case 429:
@@ -225,7 +241,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
start := geminiDailyWindowStart(now)
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
if !ok {
- stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil)
+ stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil, nil)
if err != nil {
return true, err
}
@@ -272,7 +288,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if limit > 0 {
start := now.Truncate(time.Minute)
- stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil)
+ stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil, nil)
if err != nil {
return true, err
}
@@ -302,6 +318,218 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
return true, nil
}
+// PreCheckUsageBatch performs quota precheck for multiple accounts in one request.
+// Returned map value=false means the account should be skipped.
+func (s *RateLimitService) PreCheckUsageBatch(ctx context.Context, accounts []*Account, requestedModel string) (map[int64]bool, error) {
+ result := make(map[int64]bool, len(accounts))
+ for _, account := range accounts {
+ if account == nil {
+ continue
+ }
+ result[account.ID] = true
+ }
+
+ if len(accounts) == 0 || requestedModel == "" {
+ return result, nil
+ }
+ if s.usageRepo == nil || s.geminiQuotaService == nil {
+ return result, nil
+ }
+
+ modelClass := geminiModelClassFromName(requestedModel)
+ now := time.Now()
+ dailyStart := geminiDailyWindowStart(now)
+ minuteStart := now.Truncate(time.Minute)
+
+ type quotaAccount struct {
+ account *Account
+ quota GeminiQuota
+ }
+ quotaAccounts := make([]quotaAccount, 0, len(accounts))
+ for _, account := range accounts {
+ if account == nil || account.Platform != PlatformGemini {
+ continue
+ }
+ quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account)
+ if !ok {
+ continue
+ }
+ quotaAccounts = append(quotaAccounts, quotaAccount{
+ account: account,
+ quota: quota,
+ })
+ }
+ if len(quotaAccounts) == 0 {
+ return result, nil
+ }
+
+ // 1) Daily precheck (cached + batch DB fallback)
+ dailyTotalsByID := make(map[int64]GeminiUsageTotals, len(quotaAccounts))
+ dailyMissIDs := make([]int64, 0, len(quotaAccounts))
+ for _, item := range quotaAccounts {
+ limit := geminiDailyLimit(item.quota, modelClass)
+ if limit <= 0 {
+ continue
+ }
+ accountID := item.account.ID
+ if totals, ok := s.getGeminiUsageTotals(accountID, dailyStart, now); ok {
+ dailyTotalsByID[accountID] = totals
+ continue
+ }
+ dailyMissIDs = append(dailyMissIDs, accountID)
+ }
+ if len(dailyMissIDs) > 0 {
+ totalsBatch, err := s.getGeminiUsageTotalsBatch(ctx, dailyMissIDs, dailyStart, now)
+ if err != nil {
+ return result, err
+ }
+ for _, accountID := range dailyMissIDs {
+ totals := totalsBatch[accountID]
+ dailyTotalsByID[accountID] = totals
+ s.setGeminiUsageTotals(accountID, dailyStart, now, totals)
+ }
+ }
+ for _, item := range quotaAccounts {
+ limit := geminiDailyLimit(item.quota, modelClass)
+ if limit <= 0 {
+ continue
+ }
+ accountID := item.account.ID
+ used := geminiUsedRequests(item.quota, modelClass, dailyTotalsByID[accountID], true)
+ if used >= limit {
+ resetAt := geminiDailyResetTime(now)
+ slog.Info("gemini_precheck_daily_quota_reached_batch", "account_id", accountID, "used", used, "limit", limit, "reset_at", resetAt)
+ result[accountID] = false
+ }
+ }
+
+ // 2) Minute precheck (batch DB)
+ minuteIDs := make([]int64, 0, len(quotaAccounts))
+ for _, item := range quotaAccounts {
+ accountID := item.account.ID
+ if !result[accountID] {
+ continue
+ }
+ if geminiMinuteLimit(item.quota, modelClass) <= 0 {
+ continue
+ }
+ minuteIDs = append(minuteIDs, accountID)
+ }
+ if len(minuteIDs) == 0 {
+ return result, nil
+ }
+
+ minuteTotalsByID, err := s.getGeminiUsageTotalsBatch(ctx, minuteIDs, minuteStart, now)
+ if err != nil {
+ return result, err
+ }
+ for _, item := range quotaAccounts {
+ accountID := item.account.ID
+ if !result[accountID] {
+ continue
+ }
+
+ limit := geminiMinuteLimit(item.quota, modelClass)
+ if limit <= 0 {
+ continue
+ }
+
+ used := geminiUsedRequests(item.quota, modelClass, minuteTotalsByID[accountID], false)
+ if used >= limit {
+ resetAt := minuteStart.Add(time.Minute)
+ slog.Info("gemini_precheck_minute_quota_reached_batch", "account_id", accountID, "used", used, "limit", limit, "reset_at", resetAt)
+ result[accountID] = false
+ }
+ }
+
+ return result, nil
+}
+
+func (s *RateLimitService) getGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, start, end time.Time) (map[int64]GeminiUsageTotals, error) {
+ result := make(map[int64]GeminiUsageTotals, len(accountIDs))
+ if len(accountIDs) == 0 {
+ return result, nil
+ }
+
+ ids := make([]int64, 0, len(accountIDs))
+ seen := make(map[int64]struct{}, len(accountIDs))
+ for _, accountID := range accountIDs {
+ if accountID <= 0 {
+ continue
+ }
+ if _, ok := seen[accountID]; ok {
+ continue
+ }
+ seen[accountID] = struct{}{}
+ ids = append(ids, accountID)
+ }
+ if len(ids) == 0 {
+ return result, nil
+ }
+
+ if batchReader, ok := s.usageRepo.(geminiUsageTotalsBatchProvider); ok {
+ stats, err := batchReader.GetGeminiUsageTotalsBatch(ctx, ids, start, end)
+ if err != nil {
+ return nil, err
+ }
+ for _, accountID := range ids {
+ result[accountID] = stats[accountID]
+ }
+ return result, nil
+ }
+
+ for _, accountID := range ids {
+ stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, end, 0, 0, accountID, 0, nil, nil, nil)
+ if err != nil {
+ return nil, err
+ }
+ result[accountID] = geminiAggregateUsage(stats)
+ }
+ return result, nil
+}
+
+func geminiDailyLimit(quota GeminiQuota, modelClass geminiModelClass) int64 {
+ if quota.SharedRPD > 0 {
+ return quota.SharedRPD
+ }
+ switch modelClass {
+ case geminiModelFlash:
+ return quota.FlashRPD
+ default:
+ return quota.ProRPD
+ }
+}
+
+func geminiMinuteLimit(quota GeminiQuota, modelClass geminiModelClass) int64 {
+ if quota.SharedRPM > 0 {
+ return quota.SharedRPM
+ }
+ switch modelClass {
+ case geminiModelFlash:
+ return quota.FlashRPM
+ default:
+ return quota.ProRPM
+ }
+}
+
+func geminiUsedRequests(quota GeminiQuota, modelClass geminiModelClass, totals GeminiUsageTotals, daily bool) int64 {
+ if daily {
+ if quota.SharedRPD > 0 {
+ return totals.ProRequests + totals.FlashRequests
+ }
+ } else {
+ if quota.SharedRPM > 0 {
+ return totals.ProRequests + totals.FlashRequests
+ }
+ }
+ switch modelClass {
+ case geminiModelFlash:
+ return totals.FlashRequests
+ default:
+ return totals.ProRequests
+ }
+}
+
func (s *RateLimitService) getGeminiUsageTotals(accountID int64, windowStart, now time.Time) (GeminiUsageTotals, bool) {
s.usageCacheMu.RLock()
defer s.usageCacheMu.RUnlock()
diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go
index ad277ca0..b22da752 100644
--- a/backend/internal/service/redeem_service.go
+++ b/backend/internal/service/redeem_service.go
@@ -174,6 +174,33 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
return codes, nil
}
+// CreateCode creates a redeem code with caller-provided code value.
+// It is primarily used by admin integrations that require an external order ID
+// to be mapped to a deterministic redeem code.
+func (s *RedeemService) CreateCode(ctx context.Context, code *RedeemCode) error {
+ if code == nil {
+ return errors.New("redeem code is required")
+ }
+ code.Code = strings.TrimSpace(code.Code)
+ if code.Code == "" {
+ return errors.New("code is required")
+ }
+ if code.Type == "" {
+ code.Type = RedeemTypeBalance
+ }
+ if code.Type != RedeemTypeInvitation && code.Value <= 0 {
+ return errors.New("value must be greater than 0")
+ }
+ if code.Status == "" {
+ code.Status = StatusUnused
+ }
+
+ if err := s.redeemRepo.Create(ctx, code); err != nil {
+ return fmt.Errorf("create redeem code: %w", err)
+ }
+ return nil
+}
+
// checkRedeemRateLimit 检查用户兑换错误次数是否超限
func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error {
if s.cache == nil {
diff --git a/backend/internal/service/request_metadata.go b/backend/internal/service/request_metadata.go
new file mode 100644
index 00000000..5c81bbf1
--- /dev/null
+++ b/backend/internal/service/request_metadata.go
@@ -0,0 +1,216 @@
+package service
+
+import (
+ "context"
+ "sync/atomic"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
+)
+
+type requestMetadataContextKey struct{}
+
+var requestMetadataKey = requestMetadataContextKey{}
+
+type RequestMetadata struct {
+ IsMaxTokensOneHaikuRequest *bool
+ ThinkingEnabled *bool
+ PrefetchedStickyAccountID *int64
+ PrefetchedStickyGroupID *int64
+ SingleAccountRetry *bool
+ AccountSwitchCount *int
+}
+
+var (
+ requestMetadataFallbackIsMaxTokensOneHaikuTotal atomic.Int64
+ requestMetadataFallbackThinkingEnabledTotal atomic.Int64
+ requestMetadataFallbackPrefetchedStickyAccount atomic.Int64
+ requestMetadataFallbackPrefetchedStickyGroup atomic.Int64
+ requestMetadataFallbackSingleAccountRetryTotal atomic.Int64
+ requestMetadataFallbackAccountSwitchCountTotal atomic.Int64
+)
+
+func RequestMetadataFallbackStats() (isMaxTokensOneHaiku, thinkingEnabled, prefetchedStickyAccount, prefetchedStickyGroup, singleAccountRetry, accountSwitchCount int64) {
+ return requestMetadataFallbackIsMaxTokensOneHaikuTotal.Load(),
+ requestMetadataFallbackThinkingEnabledTotal.Load(),
+ requestMetadataFallbackPrefetchedStickyAccount.Load(),
+ requestMetadataFallbackPrefetchedStickyGroup.Load(),
+ requestMetadataFallbackSingleAccountRetryTotal.Load(),
+ requestMetadataFallbackAccountSwitchCountTotal.Load()
+}
+
+func metadataFromContext(ctx context.Context) *RequestMetadata {
+ if ctx == nil {
+ return nil
+ }
+ md, _ := ctx.Value(requestMetadataKey).(*RequestMetadata)
+ return md
+}
+
+func updateRequestMetadata(
+ ctx context.Context,
+ bridgeOldKeys bool,
+ update func(md *RequestMetadata),
+ legacyBridge func(ctx context.Context) context.Context,
+) context.Context {
+ if ctx == nil {
+ return nil
+ }
+ current := metadataFromContext(ctx)
+ next := &RequestMetadata{}
+ if current != nil {
+ *next = *current
+ }
+ update(next)
+ ctx = context.WithValue(ctx, requestMetadataKey, next)
+ if bridgeOldKeys && legacyBridge != nil {
+ ctx = legacyBridge(ctx)
+ }
+ return ctx
+}
+
+func WithIsMaxTokensOneHaikuRequest(ctx context.Context, value bool, bridgeOldKeys bool) context.Context {
+ return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
+ v := value
+ md.IsMaxTokensOneHaikuRequest = &v
+ }, func(base context.Context) context.Context {
+ return context.WithValue(base, ctxkey.IsMaxTokensOneHaikuRequest, value)
+ })
+}
+
+func WithThinkingEnabled(ctx context.Context, value bool, bridgeOldKeys bool) context.Context {
+ return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
+ v := value
+ md.ThinkingEnabled = &v
+ }, func(base context.Context) context.Context {
+ return context.WithValue(base, ctxkey.ThinkingEnabled, value)
+ })
+}
+
+func WithPrefetchedStickySession(ctx context.Context, accountID, groupID int64, bridgeOldKeys bool) context.Context {
+ return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
+ account := accountID
+ group := groupID
+ md.PrefetchedStickyAccountID = &account
+ md.PrefetchedStickyGroupID = &group
+ }, func(base context.Context) context.Context {
+ bridged := context.WithValue(base, ctxkey.PrefetchedStickyAccountID, accountID)
+ return context.WithValue(bridged, ctxkey.PrefetchedStickyGroupID, groupID)
+ })
+}
+
+func WithSingleAccountRetry(ctx context.Context, value bool, bridgeOldKeys bool) context.Context {
+ return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
+ v := value
+ md.SingleAccountRetry = &v
+ }, func(base context.Context) context.Context {
+ return context.WithValue(base, ctxkey.SingleAccountRetry, value)
+ })
+}
+
+func WithAccountSwitchCount(ctx context.Context, value int, bridgeOldKeys bool) context.Context {
+ return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
+ v := value
+ md.AccountSwitchCount = &v
+ }, func(base context.Context) context.Context {
+ return context.WithValue(base, ctxkey.AccountSwitchCount, value)
+ })
+}
+
+func IsMaxTokensOneHaikuRequestFromContext(ctx context.Context) (bool, bool) {
+ if md := metadataFromContext(ctx); md != nil && md.IsMaxTokensOneHaikuRequest != nil {
+ return *md.IsMaxTokensOneHaikuRequest, true
+ }
+ if ctx == nil {
+ return false, false
+ }
+ if value, ok := ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok {
+ requestMetadataFallbackIsMaxTokensOneHaikuTotal.Add(1)
+ return value, true
+ }
+ return false, false
+}
+
+func ThinkingEnabledFromContext(ctx context.Context) (bool, bool) {
+ if md := metadataFromContext(ctx); md != nil && md.ThinkingEnabled != nil {
+ return *md.ThinkingEnabled, true
+ }
+ if ctx == nil {
+ return false, false
+ }
+ if value, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok {
+ requestMetadataFallbackThinkingEnabledTotal.Add(1)
+ return value, true
+ }
+ return false, false
+}
+
+func PrefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) {
+ if md := metadataFromContext(ctx); md != nil && md.PrefetchedStickyGroupID != nil {
+ return *md.PrefetchedStickyGroupID, true
+ }
+ if ctx == nil {
+ return 0, false
+ }
+ v := ctx.Value(ctxkey.PrefetchedStickyGroupID)
+ switch t := v.(type) {
+ case int64:
+ requestMetadataFallbackPrefetchedStickyGroup.Add(1)
+ return t, true
+ case int:
+ requestMetadataFallbackPrefetchedStickyGroup.Add(1)
+ return int64(t), true
+ }
+ return 0, false
+}
+
+func PrefetchedStickyAccountIDFromContext(ctx context.Context) (int64, bool) {
+ if md := metadataFromContext(ctx); md != nil && md.PrefetchedStickyAccountID != nil {
+ return *md.PrefetchedStickyAccountID, true
+ }
+ if ctx == nil {
+ return 0, false
+ }
+ v := ctx.Value(ctxkey.PrefetchedStickyAccountID)
+ switch t := v.(type) {
+ case int64:
+ requestMetadataFallbackPrefetchedStickyAccount.Add(1)
+ return t, true
+ case int:
+ requestMetadataFallbackPrefetchedStickyAccount.Add(1)
+ return int64(t), true
+ }
+ return 0, false
+}
+
+func SingleAccountRetryFromContext(ctx context.Context) (bool, bool) {
+ if md := metadataFromContext(ctx); md != nil && md.SingleAccountRetry != nil {
+ return *md.SingleAccountRetry, true
+ }
+ if ctx == nil {
+ return false, false
+ }
+ if value, ok := ctx.Value(ctxkey.SingleAccountRetry).(bool); ok {
+ requestMetadataFallbackSingleAccountRetryTotal.Add(1)
+ return value, true
+ }
+ return false, false
+}
+
+func AccountSwitchCountFromContext(ctx context.Context) (int, bool) {
+ if md := metadataFromContext(ctx); md != nil && md.AccountSwitchCount != nil {
+ return *md.AccountSwitchCount, true
+ }
+ if ctx == nil {
+ return 0, false
+ }
+ v := ctx.Value(ctxkey.AccountSwitchCount)
+ switch t := v.(type) {
+ case int:
+ requestMetadataFallbackAccountSwitchCountTotal.Add(1)
+ return t, true
+ case int64:
+ requestMetadataFallbackAccountSwitchCountTotal.Add(1)
+ return int(t), true
+ }
+ return 0, false
+}
diff --git a/backend/internal/service/request_metadata_test.go b/backend/internal/service/request_metadata_test.go
new file mode 100644
index 00000000..7d192699
--- /dev/null
+++ b/backend/internal/service/request_metadata_test.go
@@ -0,0 +1,119 @@
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
+ "github.com/stretchr/testify/require"
+)
+
+func TestRequestMetadataWriteAndRead_NoBridge(t *testing.T) {
+ ctx := context.Background()
+ ctx = WithIsMaxTokensOneHaikuRequest(ctx, true, false)
+ ctx = WithThinkingEnabled(ctx, true, false)
+ ctx = WithPrefetchedStickySession(ctx, 123, 456, false)
+ ctx = WithSingleAccountRetry(ctx, true, false)
+ ctx = WithAccountSwitchCount(ctx, 2, false)
+
+ isHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(ctx)
+ require.True(t, ok)
+ require.True(t, isHaiku)
+
+ thinking, ok := ThinkingEnabledFromContext(ctx)
+ require.True(t, ok)
+ require.True(t, thinking)
+
+ accountID, ok := PrefetchedStickyAccountIDFromContext(ctx)
+ require.True(t, ok)
+ require.Equal(t, int64(123), accountID)
+
+ groupID, ok := PrefetchedStickyGroupIDFromContext(ctx)
+ require.True(t, ok)
+ require.Equal(t, int64(456), groupID)
+
+ singleRetry, ok := SingleAccountRetryFromContext(ctx)
+ require.True(t, ok)
+ require.True(t, singleRetry)
+
+ switchCount, ok := AccountSwitchCountFromContext(ctx)
+ require.True(t, ok)
+ require.Equal(t, 2, switchCount)
+
+ require.Nil(t, ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest))
+ require.Nil(t, ctx.Value(ctxkey.ThinkingEnabled))
+ require.Nil(t, ctx.Value(ctxkey.PrefetchedStickyAccountID))
+ require.Nil(t, ctx.Value(ctxkey.PrefetchedStickyGroupID))
+ require.Nil(t, ctx.Value(ctxkey.SingleAccountRetry))
+ require.Nil(t, ctx.Value(ctxkey.AccountSwitchCount))
+}
+
+func TestRequestMetadataWrite_BridgeLegacyKeys(t *testing.T) {
+ ctx := context.Background()
+ ctx = WithIsMaxTokensOneHaikuRequest(ctx, true, true)
+ ctx = WithThinkingEnabled(ctx, true, true)
+ ctx = WithPrefetchedStickySession(ctx, 123, 456, true)
+ ctx = WithSingleAccountRetry(ctx, true, true)
+ ctx = WithAccountSwitchCount(ctx, 2, true)
+
+ require.Equal(t, true, ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest))
+ require.Equal(t, true, ctx.Value(ctxkey.ThinkingEnabled))
+ require.Equal(t, int64(123), ctx.Value(ctxkey.PrefetchedStickyAccountID))
+ require.Equal(t, int64(456), ctx.Value(ctxkey.PrefetchedStickyGroupID))
+ require.Equal(t, true, ctx.Value(ctxkey.SingleAccountRetry))
+ require.Equal(t, 2, ctx.Value(ctxkey.AccountSwitchCount))
+}
+
+func TestRequestMetadataRead_LegacyFallbackAndStats(t *testing.T) {
+ beforeHaiku, beforeThinking, beforeAccount, beforeGroup, beforeSingleRetry, beforeSwitchCount := RequestMetadataFallbackStats()
+
+ ctx := context.Background()
+ ctx = context.WithValue(ctx, ctxkey.IsMaxTokensOneHaikuRequest, true)
+ ctx = context.WithValue(ctx, ctxkey.ThinkingEnabled, true)
+ ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyAccountID, int64(321))
+ ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(654))
+ ctx = context.WithValue(ctx, ctxkey.SingleAccountRetry, true)
+ ctx = context.WithValue(ctx, ctxkey.AccountSwitchCount, int64(3))
+
+ isHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(ctx)
+ require.True(t, ok)
+ require.True(t, isHaiku)
+
+ thinking, ok := ThinkingEnabledFromContext(ctx)
+ require.True(t, ok)
+ require.True(t, thinking)
+
+ accountID, ok := PrefetchedStickyAccountIDFromContext(ctx)
+ require.True(t, ok)
+ require.Equal(t, int64(321), accountID)
+
+ groupID, ok := PrefetchedStickyGroupIDFromContext(ctx)
+ require.True(t, ok)
+ require.Equal(t, int64(654), groupID)
+
+ singleRetry, ok := SingleAccountRetryFromContext(ctx)
+ require.True(t, ok)
+ require.True(t, singleRetry)
+
+ switchCount, ok := AccountSwitchCountFromContext(ctx)
+ require.True(t, ok)
+ require.Equal(t, 3, switchCount)
+
+ afterHaiku, afterThinking, afterAccount, afterGroup, afterSingleRetry, afterSwitchCount := RequestMetadataFallbackStats()
+ require.Equal(t, beforeHaiku+1, afterHaiku)
+ require.Equal(t, beforeThinking+1, afterThinking)
+ require.Equal(t, beforeAccount+1, afterAccount)
+ require.Equal(t, beforeGroup+1, afterGroup)
+ require.Equal(t, beforeSingleRetry+1, afterSingleRetry)
+ require.Equal(t, beforeSwitchCount+1, afterSwitchCount)
+}
+
+func TestRequestMetadataRead_PreferMetadataOverLegacy(t *testing.T) {
+ ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, false)
+ ctx = WithThinkingEnabled(ctx, true, false)
+
+ thinking, ok := ThinkingEnabledFromContext(ctx)
+ require.True(t, ok)
+ require.True(t, thinking)
+ require.Equal(t, false, ctx.Value(ctxkey.ThinkingEnabled))
+}
diff --git a/backend/internal/service/response_header_filter.go b/backend/internal/service/response_header_filter.go
new file mode 100644
index 00000000..81012b01
--- /dev/null
+++ b/backend/internal/service/response_header_filter.go
@@ -0,0 +1,13 @@
+package service
+
+import (
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
+)
+
+func compileResponseHeaderFilter(cfg *config.Config) *responseheaders.CompiledHeaderFilter {
+ if cfg == nil {
+ return nil
+ }
+ return responseheaders.CompileHeaderFilter(cfg.Security.ResponseHeaders)
+}
diff --git a/backend/internal/service/rpm_cache.go b/backend/internal/service/rpm_cache.go
new file mode 100644
index 00000000..07036219
--- /dev/null
+++ b/backend/internal/service/rpm_cache.go
@@ -0,0 +1,17 @@
+package service
+
+import "context"
+
+// RPMCache RPM 计数器缓存接口
+// 用于 Anthropic OAuth/SetupToken 账号的每分钟请求数限制
+type RPMCache interface {
+ // IncrementRPM 原子递增并返回当前分钟的计数
+ // 使用 Redis 服务器时间确定 minute key,避免多实例时钟偏差
+ IncrementRPM(ctx context.Context, accountID int64) (count int, err error)
+
+ // GetRPM 获取当前分钟的 RPM 计数
+ GetRPM(ctx context.Context, accountID int64) (count int, err error)
+
+ // GetRPMBatch 批量获取多个账号的 RPM 计数(使用 Pipeline)
+ GetRPMBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
+}
diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go
index 4d95743c..9f8fa14a 100644
--- a/backend/internal/service/scheduler_snapshot_service.go
+++ b/backend/internal/service/scheduler_snapshot_service.go
@@ -305,13 +305,78 @@ func (s *SchedulerSnapshotService) handleBulkAccountEvent(ctx context.Context, p
if payload == nil {
return nil
}
- ids := parseInt64Slice(payload["account_ids"])
- for _, id := range ids {
- if err := s.handleAccountEvent(ctx, &id, payload); err != nil {
- return err
+ if s.accountRepo == nil {
+ return nil
+ }
+
+ rawIDs := parseInt64Slice(payload["account_ids"])
+ if len(rawIDs) == 0 {
+ return nil
+ }
+
+ ids := make([]int64, 0, len(rawIDs))
+ seen := make(map[int64]struct{}, len(rawIDs))
+ for _, id := range rawIDs {
+ if id <= 0 {
+ continue
+ }
+ if _, exists := seen[id]; exists {
+ continue
+ }
+ seen[id] = struct{}{}
+ ids = append(ids, id)
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+
+ preloadGroupIDs := parseInt64Slice(payload["group_ids"])
+ accounts, err := s.accountRepo.GetByIDs(ctx, ids)
+ if err != nil {
+ return err
+ }
+
+ found := make(map[int64]struct{}, len(accounts))
+ rebuildGroupSet := make(map[int64]struct{}, len(preloadGroupIDs))
+ for _, gid := range preloadGroupIDs {
+ if gid > 0 {
+ rebuildGroupSet[gid] = struct{}{}
}
}
- return nil
+
+ for _, account := range accounts {
+ if account == nil || account.ID <= 0 {
+ continue
+ }
+ found[account.ID] = struct{}{}
+ if s.cache != nil {
+ if err := s.cache.SetAccount(ctx, account); err != nil {
+ return err
+ }
+ }
+ for _, gid := range account.GroupIDs {
+ if gid > 0 {
+ rebuildGroupSet[gid] = struct{}{}
+ }
+ }
+ }
+
+ if s.cache != nil {
+ for _, id := range ids {
+ if _, ok := found[id]; ok {
+ continue
+ }
+ if err := s.cache.DeleteAccount(ctx, id); err != nil {
+ return err
+ }
+ }
+ }
+
+ rebuildGroupIDs := make([]int64, 0, len(rebuildGroupSet))
+ for gid := range rebuildGroupSet {
+ rebuildGroupIDs = append(rebuildGroupIDs, gid)
+ }
+ return s.rebuildByGroupIDs(ctx, rebuildGroupIDs, "account_bulk_change")
}
func (s *SchedulerSnapshotService) handleAccountEvent(ctx context.Context, accountID *int64, payload map[string]any) error {
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index 222a4758..c708e061 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -7,16 +7,30 @@ import (
"encoding/json"
"errors"
"fmt"
+ "log/slog"
"strconv"
"strings"
+ "sync/atomic"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "golang.org/x/sync/singleflight"
)
var (
- ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
- ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
+ ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
+ ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
+ ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found")
+ ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists")
+ ErrDefaultSubGroupInvalid = infraerrors.BadRequest(
+ "DEFAULT_SUBSCRIPTION_GROUP_INVALID",
+ "default subscription group must exist and be subscription type",
+ )
+ ErrDefaultSubGroupDuplicate = infraerrors.BadRequest(
+ "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE",
+ "default subscription group cannot be duplicated",
+ )
)
type SettingRepository interface {
@@ -29,12 +43,40 @@ type SettingRepository interface {
Delete(ctx context.Context, key string) error
}
+// cachedMinVersion 缓存最低 Claude Code 版本号(进程内缓存,60s TTL)
+type cachedMinVersion struct {
+ value string // 空字符串 = 不检查
+ expiresAt int64 // unix nano
+}
+
+// minVersionCache 最低版本号进程内缓存
+var minVersionCache atomic.Value // *cachedMinVersion
+
+// minVersionSF 防止缓存过期时 thundering herd
+var minVersionSF singleflight.Group
+
+// minVersionCacheTTL 缓存有效期
+const minVersionCacheTTL = 60 * time.Second
+
+// minVersionErrorTTL DB 错误时的短缓存,快速重试
+const minVersionErrorTTL = 5 * time.Second
+
+// minVersionDBTimeout singleflight 内 DB 查询超时,独立于请求 context
+const minVersionDBTimeout = 5 * time.Second
+
+// DefaultSubscriptionGroupReader validates group references used by default subscriptions.
+type DefaultSubscriptionGroupReader interface {
+ GetByID(ctx context.Context, id int64) (*Group, error)
+}
+
// SettingService 系统设置服务
type SettingService struct {
- settingRepo SettingRepository
- cfg *config.Config
- onUpdate func() // Callback when settings are updated (for cache invalidation)
- version string // Application version
+ settingRepo SettingRepository
+ defaultSubGroupReader DefaultSubscriptionGroupReader
+ cfg *config.Config
+ onUpdate func() // Callback when settings are updated (for cache invalidation)
+ onS3Update func() // Callback when Sora S3 settings are updated
+ version string // Application version
}
// NewSettingService 创建系统设置服务实例
@@ -45,6 +87,11 @@ func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *Setti
}
}
+// SetDefaultSubscriptionGroupReader injects an optional group reader for default subscription validation.
+func (s *SettingService) SetDefaultSubscriptionGroupReader(reader DefaultSubscriptionGroupReader) {
+ s.defaultSubGroupReader = reader
+}
+
// GetAllSettings 获取所有系统设置
func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) {
settings, err := s.settingRepo.GetAll(ctx)
@@ -76,6 +123,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyHideCcsImportButton,
SettingKeyPurchaseSubscriptionEnabled,
SettingKeyPurchaseSubscriptionURL,
+ SettingKeySoraClientEnabled,
SettingKeyLinuxDoConnectEnabled,
}
@@ -114,6 +162,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
+ SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
LinuxDoOAuthEnabled: linuxDoEnabled,
}, nil
}
@@ -124,6 +173,11 @@ func (s *SettingService) SetOnUpdateCallback(callback func()) {
s.onUpdate = callback
}
+// SetOnS3UpdateCallback 设置 Sora S3 配置变更时的回调函数(用于刷新 S3 客户端缓存)。
+func (s *SettingService) SetOnS3UpdateCallback(callback func()) {
+ s.onS3Update = callback
+}
+
// SetVersion sets the application version for injection into public settings
func (s *SettingService) SetVersion(version string) {
s.version = version
@@ -157,6 +211,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
+ SoraClientEnabled bool `json:"sora_client_enabled"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version,omitempty"`
}{
@@ -178,6 +233,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
+ SoraClientEnabled: settings.SoraClientEnabled,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: s.version,
}, nil
@@ -185,6 +241,10 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
// UpdateSettings 更新系统设置
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
+ if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil {
+ return err
+ }
+
updates := make(map[string]string)
// 注册设置
@@ -232,10 +292,16 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton)
updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled)
updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL)
+ updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled)
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
+ defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
+ if err != nil {
+ return fmt.Errorf("marshal default subscriptions: %w", err)
+ }
+ updates[SettingKeyDefaultSubscriptions] = string(defaultSubsJSON)
// Model fallback configuration
updates[SettingKeyEnableModelFallback] = strconv.FormatBool(settings.EnableModelFallback)
@@ -256,13 +322,63 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyOpsMetricsIntervalSeconds] = strconv.Itoa(settings.OpsMetricsIntervalSeconds)
}
- err := s.settingRepo.SetMultiple(ctx, updates)
- if err == nil && s.onUpdate != nil {
- s.onUpdate() // Invalidate cache after settings update
+ // Claude Code version check
+ updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion
+
+ err = s.settingRepo.SetMultiple(ctx, updates)
+ if err == nil {
+ // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
+ minVersionSF.Forget("min_version")
+ minVersionCache.Store(&cachedMinVersion{
+ value: settings.MinClaudeCodeVersion,
+ expiresAt: time.Now().Add(minVersionCacheTTL).UnixNano(),
+ })
+ if s.onUpdate != nil {
+ s.onUpdate() // Invalidate cache after settings update
+ }
}
return err
}
+func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error {
+ if len(items) == 0 {
+ return nil
+ }
+
+ checked := make(map[int64]struct{}, len(items))
+ for _, item := range items {
+ if item.GroupID <= 0 {
+ continue
+ }
+ if _, ok := checked[item.GroupID]; ok {
+ return ErrDefaultSubGroupDuplicate.WithMetadata(map[string]string{
+ "group_id": strconv.FormatInt(item.GroupID, 10),
+ })
+ }
+ checked[item.GroupID] = struct{}{}
+ if s.defaultSubGroupReader == nil {
+ continue
+ }
+
+ group, err := s.defaultSubGroupReader.GetByID(ctx, item.GroupID)
+ if err != nil {
+ if errors.Is(err, ErrGroupNotFound) {
+ return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{
+ "group_id": strconv.FormatInt(item.GroupID, 10),
+ })
+ }
+ return fmt.Errorf("get default subscription group %d: %w", item.GroupID, err)
+ }
+ if !group.IsSubscriptionType() {
+ return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{
+ "group_id": strconv.FormatInt(item.GroupID, 10),
+ })
+ }
+ }
+
+ return nil
+}
+
// IsRegistrationEnabled 检查是否开放注册
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
@@ -362,6 +478,15 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
return s.cfg.Default.UserBalance
}
+// GetDefaultSubscriptions 获取新用户默认订阅配置列表。
+func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultSubscriptionSetting {
+ value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultSubscriptions)
+ if err != nil {
+ return nil
+ }
+ return parseDefaultSubscriptions(value)
+}
+
// InitializeDefaultSettings 初始化默认设置
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 检查是否已有设置
@@ -383,8 +508,10 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo: "",
SettingKeyPurchaseSubscriptionEnabled: "false",
SettingKeyPurchaseSubscriptionURL: "",
+ SettingKeySoraClientEnabled: "false",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
+ SettingKeyDefaultSubscriptions: "[]",
SettingKeySMTPPort: "587",
SettingKeySMTPUseTLS: "false",
// Model fallback defaults
@@ -402,6 +529,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyOpsRealtimeMonitoringEnabled: "true",
SettingKeyOpsQueryModeDefault: "auto",
SettingKeyOpsMetricsIntervalSeconds: "60",
+
+ // Claude Code version check (default: empty = disabled)
+ SettingKeyMinClaudeCodeVersion: "",
}
return s.settingRepo.SetMultiple(ctx, defaults)
@@ -436,6 +566,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
+ SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
}
// 解析整数类型
@@ -457,6 +588,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
} else {
result.DefaultBalance = s.cfg.Default.UserBalance
}
+ result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions])
// 敏感信息直接返回,方便测试连接时使用
result.SMTPPassword = settings[SettingKeySMTPPassword]
@@ -526,6 +658,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
}
}
+ // Claude Code version check
+ result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion]
+
return result
}
@@ -538,6 +673,31 @@ func isFalseSettingValue(value string) bool {
}
}
+func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return nil
+ }
+
+ var items []DefaultSubscriptionSetting
+ if err := json.Unmarshal([]byte(raw), &items); err != nil {
+ return nil
+ }
+
+ normalized := make([]DefaultSubscriptionSetting, 0, len(items))
+ for _, item := range items {
+ if item.GroupID <= 0 || item.ValidityDays <= 0 {
+ continue
+ }
+ if item.ValidityDays > MaxValidityDays {
+ item.ValidityDays = MaxValidityDays
+ }
+ normalized = append(normalized, item)
+ }
+
+ return normalized
+}
+
// getStringOrDefault 获取字符串值或默认值
func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string {
if value, ok := settings[key]; ok && value != "" {
@@ -823,6 +983,49 @@ func (s *SettingService) GetStreamTimeoutSettings(ctx context.Context) (*StreamT
return &settings, nil
}
+// GetMinClaudeCodeVersion 获取最低 Claude Code 版本号要求
+// 使用进程内 atomic.Value 缓存,60 秒 TTL,热路径零锁开销
+// singleflight 防止缓存过期时 thundering herd
+// 返回空字符串表示不做版本检查
+func (s *SettingService) GetMinClaudeCodeVersion(ctx context.Context) string {
+ if cached, ok := minVersionCache.Load().(*cachedMinVersion); ok {
+ if time.Now().UnixNano() < cached.expiresAt {
+ return cached.value
+ }
+ }
+ // singleflight: 同一时刻只有一个 goroutine 查询 DB,其余复用结果
+ result, _, _ := minVersionSF.Do("min_version", func() (any, error) {
+ // 二次检查,避免排队的 goroutine 重复查询
+ if cached, ok := minVersionCache.Load().(*cachedMinVersion); ok {
+ if time.Now().UnixNano() < cached.expiresAt {
+ return cached.value, nil
+ }
+ }
+ // 使用独立 context:断开请求取消链,避免客户端断连导致空值被长期缓存
+ dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), minVersionDBTimeout)
+ defer cancel()
+ value, err := s.settingRepo.GetValue(dbCtx, SettingKeyMinClaudeCodeVersion)
+ if err != nil {
+ // fail-open: DB 错误时不阻塞请求,但记录日志并使用短 TTL 快速重试
+ slog.Warn("failed to get min claude code version setting, skipping version check", "error", err)
+ minVersionCache.Store(&cachedMinVersion{
+ value: "",
+ expiresAt: time.Now().Add(minVersionErrorTTL).UnixNano(),
+ })
+ return "", nil
+ }
+ minVersionCache.Store(&cachedMinVersion{
+ value: value,
+ expiresAt: time.Now().Add(minVersionCacheTTL).UnixNano(),
+ })
+ return value, nil
+ })
+ if s, ok := result.(string); ok {
+ return s
+ }
+ return ""
+}
+
// SetStreamTimeoutSettings 设置流超时处理配置
func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error {
if settings == nil {
@@ -854,3 +1057,607 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings
return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data))
}
+
+type soraS3ProfilesStore struct {
+ ActiveProfileID string `json:"active_profile_id"`
+ Items []soraS3ProfileStoreItem `json:"items"`
+}
+
+type soraS3ProfileStoreItem struct {
+ ProfileID string `json:"profile_id"`
+ Name string `json:"name"`
+ Enabled bool `json:"enabled"`
+ Endpoint string `json:"endpoint"`
+ Region string `json:"region"`
+ Bucket string `json:"bucket"`
+ AccessKeyID string `json:"access_key_id"`
+ SecretAccessKey string `json:"secret_access_key"`
+ Prefix string `json:"prefix"`
+ ForcePathStyle bool `json:"force_path_style"`
+ CDNURL string `json:"cdn_url"`
+ DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
+ UpdatedAt string `json:"updated_at"`
+}
+
+// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置语义:返回当前激活配置)
+func (s *SettingService) GetSoraS3Settings(ctx context.Context) (*SoraS3Settings, error) {
+ profiles, err := s.ListSoraS3Profiles(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ activeProfile := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID)
+ if activeProfile == nil {
+ return &SoraS3Settings{}, nil
+ }
+
+ return &SoraS3Settings{
+ Enabled: activeProfile.Enabled,
+ Endpoint: activeProfile.Endpoint,
+ Region: activeProfile.Region,
+ Bucket: activeProfile.Bucket,
+ AccessKeyID: activeProfile.AccessKeyID,
+ SecretAccessKey: activeProfile.SecretAccessKey,
+ SecretAccessKeyConfigured: activeProfile.SecretAccessKeyConfigured,
+ Prefix: activeProfile.Prefix,
+ ForcePathStyle: activeProfile.ForcePathStyle,
+ CDNURL: activeProfile.CDNURL,
+ DefaultStorageQuotaBytes: activeProfile.DefaultStorageQuotaBytes,
+ }, nil
+}
+
+// SetSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置语义:写入当前激活配置)
+func (s *SettingService) SetSoraS3Settings(ctx context.Context, settings *SoraS3Settings) error {
+ if settings == nil {
+ return fmt.Errorf("settings cannot be nil")
+ }
+
+ store, err := s.loadSoraS3ProfilesStore(ctx)
+ if err != nil {
+ return err
+ }
+
+ now := time.Now().UTC().Format(time.RFC3339)
+ activeIndex := findSoraS3ProfileIndex(store.Items, store.ActiveProfileID)
+ if activeIndex < 0 {
+ activeID := "default"
+ if hasSoraS3ProfileID(store.Items, activeID) {
+ activeID = fmt.Sprintf("default-%d", time.Now().Unix())
+ }
+ store.Items = append(store.Items, soraS3ProfileStoreItem{
+ ProfileID: activeID,
+ Name: "Default",
+ UpdatedAt: now,
+ })
+ store.ActiveProfileID = activeID
+ activeIndex = len(store.Items) - 1
+ }
+
+ active := store.Items[activeIndex]
+ active.Enabled = settings.Enabled
+ active.Endpoint = strings.TrimSpace(settings.Endpoint)
+ active.Region = strings.TrimSpace(settings.Region)
+ active.Bucket = strings.TrimSpace(settings.Bucket)
+ active.AccessKeyID = strings.TrimSpace(settings.AccessKeyID)
+ active.Prefix = strings.TrimSpace(settings.Prefix)
+ active.ForcePathStyle = settings.ForcePathStyle
+ active.CDNURL = strings.TrimSpace(settings.CDNURL)
+ active.DefaultStorageQuotaBytes = maxInt64(settings.DefaultStorageQuotaBytes, 0)
+ if settings.SecretAccessKey != "" {
+ active.SecretAccessKey = settings.SecretAccessKey
+ }
+ active.UpdatedAt = now
+ store.Items[activeIndex] = active
+
+ return s.persistSoraS3ProfilesStore(ctx, store)
+}
+
+// ListSoraS3Profiles 获取 Sora S3 多配置列表
+func (s *SettingService) ListSoraS3Profiles(ctx context.Context) (*SoraS3ProfileList, error) {
+ store, err := s.loadSoraS3ProfilesStore(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return convertSoraS3ProfilesStore(store), nil
+}
+
+// CreateSoraS3Profile 创建 Sora S3 配置
+func (s *SettingService) CreateSoraS3Profile(ctx context.Context, profile *SoraS3Profile, setActive bool) (*SoraS3Profile, error) {
+ if profile == nil {
+ return nil, fmt.Errorf("profile cannot be nil")
+ }
+
+ profileID := strings.TrimSpace(profile.ProfileID)
+ if profileID == "" {
+ return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
+ }
+ name := strings.TrimSpace(profile.Name)
+ if name == "" {
+ return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required")
+ }
+
+ store, err := s.loadSoraS3ProfilesStore(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if hasSoraS3ProfileID(store.Items, profileID) {
+ return nil, ErrSoraS3ProfileExists
+ }
+
+ now := time.Now().UTC().Format(time.RFC3339)
+ store.Items = append(store.Items, soraS3ProfileStoreItem{
+ ProfileID: profileID,
+ Name: name,
+ Enabled: profile.Enabled,
+ Endpoint: strings.TrimSpace(profile.Endpoint),
+ Region: strings.TrimSpace(profile.Region),
+ Bucket: strings.TrimSpace(profile.Bucket),
+ AccessKeyID: strings.TrimSpace(profile.AccessKeyID),
+ SecretAccessKey: profile.SecretAccessKey,
+ Prefix: strings.TrimSpace(profile.Prefix),
+ ForcePathStyle: profile.ForcePathStyle,
+ CDNURL: strings.TrimSpace(profile.CDNURL),
+ DefaultStorageQuotaBytes: maxInt64(profile.DefaultStorageQuotaBytes, 0),
+ UpdatedAt: now,
+ })
+
+ if setActive || store.ActiveProfileID == "" {
+ store.ActiveProfileID = profileID
+ }
+
+ if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
+ return nil, err
+ }
+
+ profiles := convertSoraS3ProfilesStore(store)
+ created := findSoraS3ProfileByID(profiles.Items, profileID)
+ if created == nil {
+ return nil, ErrSoraS3ProfileNotFound
+ }
+ return created, nil
+}
+
+// UpdateSoraS3Profile 更新 Sora S3 配置
+func (s *SettingService) UpdateSoraS3Profile(ctx context.Context, profileID string, profile *SoraS3Profile) (*SoraS3Profile, error) {
+ if profile == nil {
+ return nil, fmt.Errorf("profile cannot be nil")
+ }
+
+ targetID := strings.TrimSpace(profileID)
+ if targetID == "" {
+ return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
+ }
+
+ store, err := s.loadSoraS3ProfilesStore(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
+ if targetIndex < 0 {
+ return nil, ErrSoraS3ProfileNotFound
+ }
+
+ target := store.Items[targetIndex]
+ name := strings.TrimSpace(profile.Name)
+ if name == "" {
+ return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required")
+ }
+ target.Name = name
+ target.Enabled = profile.Enabled
+ target.Endpoint = strings.TrimSpace(profile.Endpoint)
+ target.Region = strings.TrimSpace(profile.Region)
+ target.Bucket = strings.TrimSpace(profile.Bucket)
+ target.AccessKeyID = strings.TrimSpace(profile.AccessKeyID)
+ target.Prefix = strings.TrimSpace(profile.Prefix)
+ target.ForcePathStyle = profile.ForcePathStyle
+ target.CDNURL = strings.TrimSpace(profile.CDNURL)
+ target.DefaultStorageQuotaBytes = maxInt64(profile.DefaultStorageQuotaBytes, 0)
+ if profile.SecretAccessKey != "" {
+ target.SecretAccessKey = profile.SecretAccessKey
+ }
+ target.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
+ store.Items[targetIndex] = target
+
+ if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
+ return nil, err
+ }
+
+ profiles := convertSoraS3ProfilesStore(store)
+ updated := findSoraS3ProfileByID(profiles.Items, targetID)
+ if updated == nil {
+ return nil, ErrSoraS3ProfileNotFound
+ }
+ return updated, nil
+}
+
+// DeleteSoraS3Profile 删除 Sora S3 配置
+func (s *SettingService) DeleteSoraS3Profile(ctx context.Context, profileID string) error {
+ targetID := strings.TrimSpace(profileID)
+ if targetID == "" {
+ return infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
+ }
+
+ store, err := s.loadSoraS3ProfilesStore(ctx)
+ if err != nil {
+ return err
+ }
+
+ targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
+ if targetIndex < 0 {
+ return ErrSoraS3ProfileNotFound
+ }
+
+ store.Items = append(store.Items[:targetIndex], store.Items[targetIndex+1:]...)
+ if store.ActiveProfileID == targetID {
+ store.ActiveProfileID = ""
+ if len(store.Items) > 0 {
+ store.ActiveProfileID = store.Items[0].ProfileID
+ }
+ }
+
+ return s.persistSoraS3ProfilesStore(ctx, store)
+}
+
+// SetActiveSoraS3Profile 设置激活的 Sora S3 配置
+func (s *SettingService) SetActiveSoraS3Profile(ctx context.Context, profileID string) (*SoraS3Profile, error) {
+ targetID := strings.TrimSpace(profileID)
+ if targetID == "" {
+ return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required")
+ }
+
+ store, err := s.loadSoraS3ProfilesStore(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ targetIndex := findSoraS3ProfileIndex(store.Items, targetID)
+ if targetIndex < 0 {
+ return nil, ErrSoraS3ProfileNotFound
+ }
+
+ store.ActiveProfileID = targetID
+ store.Items[targetIndex].UpdatedAt = time.Now().UTC().Format(time.RFC3339)
+ if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil {
+ return nil, err
+ }
+
+ profiles := convertSoraS3ProfilesStore(store)
+ active := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID)
+ if active == nil {
+ return nil, ErrSoraS3ProfileNotFound
+ }
+ return active, nil
+}
+
+func (s *SettingService) loadSoraS3ProfilesStore(ctx context.Context) (*soraS3ProfilesStore, error) {
+ raw, err := s.settingRepo.GetValue(ctx, SettingKeySoraS3Profiles)
+ if err == nil {
+ trimmed := strings.TrimSpace(raw)
+ if trimmed == "" {
+ return &soraS3ProfilesStore{}, nil
+ }
+ var store soraS3ProfilesStore
+ if unmarshalErr := json.Unmarshal([]byte(trimmed), &store); unmarshalErr != nil {
+ legacy, legacyErr := s.getLegacySoraS3Settings(ctx)
+ if legacyErr != nil {
+ return nil, fmt.Errorf("unmarshal sora s3 profiles: %w", unmarshalErr)
+ }
+ if isEmptyLegacySoraS3Settings(legacy) {
+ return &soraS3ProfilesStore{}, nil
+ }
+ now := time.Now().UTC().Format(time.RFC3339)
+ return &soraS3ProfilesStore{
+ ActiveProfileID: "default",
+ Items: []soraS3ProfileStoreItem{
+ {
+ ProfileID: "default",
+ Name: "Default",
+ Enabled: legacy.Enabled,
+ Endpoint: strings.TrimSpace(legacy.Endpoint),
+ Region: strings.TrimSpace(legacy.Region),
+ Bucket: strings.TrimSpace(legacy.Bucket),
+ AccessKeyID: strings.TrimSpace(legacy.AccessKeyID),
+ SecretAccessKey: legacy.SecretAccessKey,
+ Prefix: strings.TrimSpace(legacy.Prefix),
+ ForcePathStyle: legacy.ForcePathStyle,
+ CDNURL: strings.TrimSpace(legacy.CDNURL),
+ DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0),
+ UpdatedAt: now,
+ },
+ },
+ }, nil
+ }
+ normalized := normalizeSoraS3ProfilesStore(store)
+ return &normalized, nil
+ }
+
+ if !errors.Is(err, ErrSettingNotFound) {
+ return nil, fmt.Errorf("get sora s3 profiles: %w", err)
+ }
+
+ legacy, legacyErr := s.getLegacySoraS3Settings(ctx)
+ if legacyErr != nil {
+ return nil, legacyErr
+ }
+ if isEmptyLegacySoraS3Settings(legacy) {
+ return &soraS3ProfilesStore{}, nil
+ }
+
+ now := time.Now().UTC().Format(time.RFC3339)
+ return &soraS3ProfilesStore{
+ ActiveProfileID: "default",
+ Items: []soraS3ProfileStoreItem{
+ {
+ ProfileID: "default",
+ Name: "Default",
+ Enabled: legacy.Enabled,
+ Endpoint: strings.TrimSpace(legacy.Endpoint),
+ Region: strings.TrimSpace(legacy.Region),
+ Bucket: strings.TrimSpace(legacy.Bucket),
+ AccessKeyID: strings.TrimSpace(legacy.AccessKeyID),
+ SecretAccessKey: legacy.SecretAccessKey,
+ Prefix: strings.TrimSpace(legacy.Prefix),
+ ForcePathStyle: legacy.ForcePathStyle,
+ CDNURL: strings.TrimSpace(legacy.CDNURL),
+ DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0),
+ UpdatedAt: now,
+ },
+ },
+ }, nil
+}
+
+func (s *SettingService) persistSoraS3ProfilesStore(ctx context.Context, store *soraS3ProfilesStore) error {
+ if store == nil {
+ return fmt.Errorf("sora s3 profiles store cannot be nil")
+ }
+
+ normalized := normalizeSoraS3ProfilesStore(*store)
+ data, err := json.Marshal(normalized)
+ if err != nil {
+ return fmt.Errorf("marshal sora s3 profiles: %w", err)
+ }
+
+ updates := map[string]string{
+ SettingKeySoraS3Profiles: string(data),
+ }
+
+ active := pickActiveSoraS3ProfileFromStore(normalized.Items, normalized.ActiveProfileID)
+ if active == nil {
+ updates[SettingKeySoraS3Enabled] = "false"
+ updates[SettingKeySoraS3Endpoint] = ""
+ updates[SettingKeySoraS3Region] = ""
+ updates[SettingKeySoraS3Bucket] = ""
+ updates[SettingKeySoraS3AccessKeyID] = ""
+ updates[SettingKeySoraS3Prefix] = ""
+ updates[SettingKeySoraS3ForcePathStyle] = "false"
+ updates[SettingKeySoraS3CDNURL] = ""
+ updates[SettingKeySoraDefaultStorageQuotaBytes] = "0"
+ updates[SettingKeySoraS3SecretAccessKey] = ""
+ } else {
+ updates[SettingKeySoraS3Enabled] = strconv.FormatBool(active.Enabled)
+ updates[SettingKeySoraS3Endpoint] = strings.TrimSpace(active.Endpoint)
+ updates[SettingKeySoraS3Region] = strings.TrimSpace(active.Region)
+ updates[SettingKeySoraS3Bucket] = strings.TrimSpace(active.Bucket)
+ updates[SettingKeySoraS3AccessKeyID] = strings.TrimSpace(active.AccessKeyID)
+ updates[SettingKeySoraS3Prefix] = strings.TrimSpace(active.Prefix)
+ updates[SettingKeySoraS3ForcePathStyle] = strconv.FormatBool(active.ForcePathStyle)
+ updates[SettingKeySoraS3CDNURL] = strings.TrimSpace(active.CDNURL)
+ updates[SettingKeySoraDefaultStorageQuotaBytes] = strconv.FormatInt(maxInt64(active.DefaultStorageQuotaBytes, 0), 10)
+ updates[SettingKeySoraS3SecretAccessKey] = active.SecretAccessKey
+ }
+
+ if err := s.settingRepo.SetMultiple(ctx, updates); err != nil {
+ return err
+ }
+
+ if s.onUpdate != nil {
+ s.onUpdate()
+ }
+ if s.onS3Update != nil {
+ s.onS3Update()
+ }
+ return nil
+}
+
+func (s *SettingService) getLegacySoraS3Settings(ctx context.Context) (*SoraS3Settings, error) {
+ keys := []string{
+ SettingKeySoraS3Enabled,
+ SettingKeySoraS3Endpoint,
+ SettingKeySoraS3Region,
+ SettingKeySoraS3Bucket,
+ SettingKeySoraS3AccessKeyID,
+ SettingKeySoraS3SecretAccessKey,
+ SettingKeySoraS3Prefix,
+ SettingKeySoraS3ForcePathStyle,
+ SettingKeySoraS3CDNURL,
+ SettingKeySoraDefaultStorageQuotaBytes,
+ }
+
+ values, err := s.settingRepo.GetMultiple(ctx, keys)
+ if err != nil {
+ return nil, fmt.Errorf("get legacy sora s3 settings: %w", err)
+ }
+
+ result := &SoraS3Settings{
+ Enabled: values[SettingKeySoraS3Enabled] == "true",
+ Endpoint: values[SettingKeySoraS3Endpoint],
+ Region: values[SettingKeySoraS3Region],
+ Bucket: values[SettingKeySoraS3Bucket],
+ AccessKeyID: values[SettingKeySoraS3AccessKeyID],
+ SecretAccessKey: values[SettingKeySoraS3SecretAccessKey],
+ SecretAccessKeyConfigured: values[SettingKeySoraS3SecretAccessKey] != "",
+ Prefix: values[SettingKeySoraS3Prefix],
+ ForcePathStyle: values[SettingKeySoraS3ForcePathStyle] == "true",
+ CDNURL: values[SettingKeySoraS3CDNURL],
+ }
+ if v, parseErr := strconv.ParseInt(values[SettingKeySoraDefaultStorageQuotaBytes], 10, 64); parseErr == nil {
+ result.DefaultStorageQuotaBytes = v
+ }
+ return result, nil
+}
+
+func normalizeSoraS3ProfilesStore(store soraS3ProfilesStore) soraS3ProfilesStore {
+ seen := make(map[string]struct{}, len(store.Items))
+ normalized := soraS3ProfilesStore{
+ ActiveProfileID: strings.TrimSpace(store.ActiveProfileID),
+ Items: make([]soraS3ProfileStoreItem, 0, len(store.Items)),
+ }
+ now := time.Now().UTC().Format(time.RFC3339)
+
+ for idx := range store.Items {
+ item := store.Items[idx]
+ item.ProfileID = strings.TrimSpace(item.ProfileID)
+ if item.ProfileID == "" {
+ item.ProfileID = fmt.Sprintf("profile-%d", idx+1)
+ }
+ if _, exists := seen[item.ProfileID]; exists {
+ continue
+ }
+ seen[item.ProfileID] = struct{}{}
+
+ item.Name = strings.TrimSpace(item.Name)
+ if item.Name == "" {
+ item.Name = item.ProfileID
+ }
+ item.Endpoint = strings.TrimSpace(item.Endpoint)
+ item.Region = strings.TrimSpace(item.Region)
+ item.Bucket = strings.TrimSpace(item.Bucket)
+ item.AccessKeyID = strings.TrimSpace(item.AccessKeyID)
+ item.Prefix = strings.TrimSpace(item.Prefix)
+ item.CDNURL = strings.TrimSpace(item.CDNURL)
+ item.DefaultStorageQuotaBytes = maxInt64(item.DefaultStorageQuotaBytes, 0)
+ item.UpdatedAt = strings.TrimSpace(item.UpdatedAt)
+ if item.UpdatedAt == "" {
+ item.UpdatedAt = now
+ }
+ normalized.Items = append(normalized.Items, item)
+ }
+
+ if len(normalized.Items) == 0 {
+ normalized.ActiveProfileID = ""
+ return normalized
+ }
+
+ if findSoraS3ProfileIndex(normalized.Items, normalized.ActiveProfileID) >= 0 {
+ return normalized
+ }
+
+ normalized.ActiveProfileID = normalized.Items[0].ProfileID
+ return normalized
+}
+
+func convertSoraS3ProfilesStore(store *soraS3ProfilesStore) *SoraS3ProfileList {
+ if store == nil {
+ return &SoraS3ProfileList{}
+ }
+ items := make([]SoraS3Profile, 0, len(store.Items))
+ for idx := range store.Items {
+ item := store.Items[idx]
+ items = append(items, SoraS3Profile{
+ ProfileID: item.ProfileID,
+ Name: item.Name,
+ IsActive: item.ProfileID == store.ActiveProfileID,
+ Enabled: item.Enabled,
+ Endpoint: item.Endpoint,
+ Region: item.Region,
+ Bucket: item.Bucket,
+ AccessKeyID: item.AccessKeyID,
+ SecretAccessKey: item.SecretAccessKey,
+ SecretAccessKeyConfigured: item.SecretAccessKey != "",
+ Prefix: item.Prefix,
+ ForcePathStyle: item.ForcePathStyle,
+ CDNURL: item.CDNURL,
+ DefaultStorageQuotaBytes: item.DefaultStorageQuotaBytes,
+ UpdatedAt: item.UpdatedAt,
+ })
+ }
+ return &SoraS3ProfileList{
+ ActiveProfileID: store.ActiveProfileID,
+ Items: items,
+ }
+}
+
+func pickActiveSoraS3Profile(items []SoraS3Profile, activeProfileID string) *SoraS3Profile {
+ for idx := range items {
+ if items[idx].ProfileID == activeProfileID {
+ return &items[idx]
+ }
+ }
+ if len(items) == 0 {
+ return nil
+ }
+ return &items[0]
+}
+
+func findSoraS3ProfileByID(items []SoraS3Profile, profileID string) *SoraS3Profile {
+ for idx := range items {
+ if items[idx].ProfileID == profileID {
+ return &items[idx]
+ }
+ }
+ return nil
+}
+
+func pickActiveSoraS3ProfileFromStore(items []soraS3ProfileStoreItem, activeProfileID string) *soraS3ProfileStoreItem {
+ for idx := range items {
+ if items[idx].ProfileID == activeProfileID {
+ return &items[idx]
+ }
+ }
+ if len(items) == 0 {
+ return nil
+ }
+ return &items[0]
+}
+
+func findSoraS3ProfileIndex(items []soraS3ProfileStoreItem, profileID string) int {
+ for idx := range items {
+ if items[idx].ProfileID == profileID {
+ return idx
+ }
+ }
+ return -1
+}
+
+func hasSoraS3ProfileID(items []soraS3ProfileStoreItem, profileID string) bool {
+ return findSoraS3ProfileIndex(items, profileID) >= 0
+}
+
+func isEmptyLegacySoraS3Settings(settings *SoraS3Settings) bool {
+ if settings == nil {
+ return true
+ }
+ if settings.Enabled {
+ return false
+ }
+ if strings.TrimSpace(settings.Endpoint) != "" {
+ return false
+ }
+ if strings.TrimSpace(settings.Region) != "" {
+ return false
+ }
+ if strings.TrimSpace(settings.Bucket) != "" {
+ return false
+ }
+ if strings.TrimSpace(settings.AccessKeyID) != "" {
+ return false
+ }
+ if settings.SecretAccessKey != "" {
+ return false
+ }
+ if strings.TrimSpace(settings.Prefix) != "" {
+ return false
+ }
+ if strings.TrimSpace(settings.CDNURL) != "" {
+ return false
+ }
+ return settings.DefaultStorageQuotaBytes == 0
+}
+
+func maxInt64(value int64, min int64) int64 {
+ if value < min {
+ return min
+ }
+ return value
+}
diff --git a/backend/internal/service/setting_service_update_test.go b/backend/internal/service/setting_service_update_test.go
new file mode 100644
index 00000000..ec64511f
--- /dev/null
+++ b/backend/internal/service/setting_service_update_test.go
@@ -0,0 +1,182 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/stretchr/testify/require"
+)
+
+type settingUpdateRepoStub struct {
+ updates map[string]string
+}
+
+func (s *settingUpdateRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *settingUpdateRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ panic("unexpected GetValue call")
+}
+
+func (s *settingUpdateRepoStub) Set(ctx context.Context, key, value string) error {
+ panic("unexpected Set call")
+}
+
+func (s *settingUpdateRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ panic("unexpected GetMultiple call")
+}
+
+func (s *settingUpdateRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ s.updates = make(map[string]string, len(settings))
+ for k, v := range settings {
+ s.updates[k] = v
+ }
+ return nil
+}
+
+func (s *settingUpdateRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *settingUpdateRepoStub) Delete(ctx context.Context, key string) error {
+ panic("unexpected Delete call")
+}
+
+type defaultSubGroupReaderStub struct {
+ byID map[int64]*Group
+ errBy map[int64]error
+ calls []int64
+}
+
+func (s *defaultSubGroupReaderStub) GetByID(ctx context.Context, id int64) (*Group, error) {
+ s.calls = append(s.calls, id)
+ if err, ok := s.errBy[id]; ok {
+ return nil, err
+ }
+ if g, ok := s.byID[id]; ok {
+ return g, nil
+ }
+ return nil, ErrGroupNotFound
+}
+
+func TestSettingService_UpdateSettings_DefaultSubscriptions_ValidGroup(t *testing.T) {
+ repo := &settingUpdateRepoStub{}
+ groupReader := &defaultSubGroupReaderStub{
+ byID: map[int64]*Group{
+ 11: {ID: 11, SubscriptionType: SubscriptionTypeSubscription},
+ },
+ }
+ svc := NewSettingService(repo, &config.Config{})
+ svc.SetDefaultSubscriptionGroupReader(groupReader)
+
+ err := svc.UpdateSettings(context.Background(), &SystemSettings{
+ DefaultSubscriptions: []DefaultSubscriptionSetting{
+ {GroupID: 11, ValidityDays: 30},
+ },
+ })
+ require.NoError(t, err)
+ require.Equal(t, []int64{11}, groupReader.calls)
+
+ raw, ok := repo.updates[SettingKeyDefaultSubscriptions]
+ require.True(t, ok)
+
+ var got []DefaultSubscriptionSetting
+ require.NoError(t, json.Unmarshal([]byte(raw), &got))
+ require.Equal(t, []DefaultSubscriptionSetting{
+ {GroupID: 11, ValidityDays: 30},
+ }, got)
+}
+
+func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsNonSubscriptionGroup(t *testing.T) {
+ repo := &settingUpdateRepoStub{}
+ groupReader := &defaultSubGroupReaderStub{
+ byID: map[int64]*Group{
+ 12: {ID: 12, SubscriptionType: SubscriptionTypeStandard},
+ },
+ }
+ svc := NewSettingService(repo, &config.Config{})
+ svc.SetDefaultSubscriptionGroupReader(groupReader)
+
+ err := svc.UpdateSettings(context.Background(), &SystemSettings{
+ DefaultSubscriptions: []DefaultSubscriptionSetting{
+ {GroupID: 12, ValidityDays: 7},
+ },
+ })
+ require.Error(t, err)
+ require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_INVALID", infraerrors.Reason(err))
+ require.Nil(t, repo.updates)
+}
+
+func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsNotFoundGroup(t *testing.T) {
+ repo := &settingUpdateRepoStub{}
+ groupReader := &defaultSubGroupReaderStub{
+ errBy: map[int64]error{
+ 13: ErrGroupNotFound,
+ },
+ }
+ svc := NewSettingService(repo, &config.Config{})
+ svc.SetDefaultSubscriptionGroupReader(groupReader)
+
+ err := svc.UpdateSettings(context.Background(), &SystemSettings{
+ DefaultSubscriptions: []DefaultSubscriptionSetting{
+ {GroupID: 13, ValidityDays: 7},
+ },
+ })
+ require.Error(t, err)
+ require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_INVALID", infraerrors.Reason(err))
+ require.Equal(t, "13", infraerrors.FromError(err).Metadata["group_id"])
+ require.Nil(t, repo.updates)
+}
+
+func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGroup(t *testing.T) {
+ repo := &settingUpdateRepoStub{}
+ groupReader := &defaultSubGroupReaderStub{
+ byID: map[int64]*Group{
+ 11: {ID: 11, SubscriptionType: SubscriptionTypeSubscription},
+ },
+ }
+ svc := NewSettingService(repo, &config.Config{})
+ svc.SetDefaultSubscriptionGroupReader(groupReader)
+
+ err := svc.UpdateSettings(context.Background(), &SystemSettings{
+ DefaultSubscriptions: []DefaultSubscriptionSetting{
+ {GroupID: 11, ValidityDays: 30},
+ {GroupID: 11, ValidityDays: 60},
+ },
+ })
+ require.Error(t, err)
+ require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", infraerrors.Reason(err))
+ require.Equal(t, "11", infraerrors.FromError(err).Metadata["group_id"])
+ require.Nil(t, repo.updates)
+}
+
+func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGroupWithoutGroupReader(t *testing.T) {
+ repo := &settingUpdateRepoStub{}
+ svc := NewSettingService(repo, &config.Config{})
+
+ err := svc.UpdateSettings(context.Background(), &SystemSettings{
+ DefaultSubscriptions: []DefaultSubscriptionSetting{
+ {GroupID: 11, ValidityDays: 30},
+ {GroupID: 11, ValidityDays: 60},
+ },
+ })
+ require.Error(t, err)
+ require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", infraerrors.Reason(err))
+ require.Equal(t, "11", infraerrors.FromError(err).Metadata["group_id"])
+ require.Nil(t, repo.updates)
+}
+
+func TestParseDefaultSubscriptions_NormalizesValues(t *testing.T) {
+ got := parseDefaultSubscriptions(`[{"group_id":11,"validity_days":30},{"group_id":11,"validity_days":60},{"group_id":0,"validity_days":10},{"group_id":12,"validity_days":99999}]`)
+ require.Equal(t, []DefaultSubscriptionSetting{
+ {GroupID: 11, ValidityDays: 30},
+ {GroupID: 11, ValidityDays: 60},
+ {GroupID: 12, ValidityDays: MaxValidityDays},
+ }, got)
+}
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index 0c7bab67..5a441ea1 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -39,9 +39,11 @@ type SystemSettings struct {
HideCcsImportButton bool
PurchaseSubscriptionEnabled bool
PurchaseSubscriptionURL string
+ SoraClientEnabled bool
- DefaultConcurrency int
- DefaultBalance float64
+ DefaultConcurrency int
+ DefaultBalance float64
+ DefaultSubscriptions []DefaultSubscriptionSetting
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -59,6 +61,14 @@ type SystemSettings struct {
OpsRealtimeMonitoringEnabled bool
OpsQueryModeDefault string
OpsMetricsIntervalSeconds int
+
+ // Claude Code version check
+ MinClaudeCodeVersion string
+}
+
+type DefaultSubscriptionSetting struct {
+ GroupID int64 `json:"group_id"`
+ ValidityDays int `json:"validity_days"`
}
type PublicSettings struct {
@@ -81,11 +91,52 @@ type PublicSettings struct {
PurchaseSubscriptionEnabled bool
PurchaseSubscriptionURL string
+ SoraClientEnabled bool
LinuxDoOAuthEnabled bool
Version string
}
+// SoraS3Settings Sora S3 存储配置
+type SoraS3Settings struct {
+ Enabled bool `json:"enabled"`
+ Endpoint string `json:"endpoint"`
+ Region string `json:"region"`
+ Bucket string `json:"bucket"`
+ AccessKeyID string `json:"access_key_id"`
+ SecretAccessKey string `json:"secret_access_key"` // 仅内部使用,不直接返回前端
+ SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用
+ Prefix string `json:"prefix"`
+ ForcePathStyle bool `json:"force_path_style"`
+ CDNURL string `json:"cdn_url"`
+ DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
+}
+
+// SoraS3Profile Sora S3 多配置项(服务内部模型)
+type SoraS3Profile struct {
+ ProfileID string `json:"profile_id"`
+ Name string `json:"name"`
+ IsActive bool `json:"is_active"`
+ Enabled bool `json:"enabled"`
+ Endpoint string `json:"endpoint"`
+ Region string `json:"region"`
+ Bucket string `json:"bucket"`
+ AccessKeyID string `json:"access_key_id"`
+ SecretAccessKey string `json:"-"` // 仅内部使用,不直接返回前端
+ SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用
+ Prefix string `json:"prefix"`
+ ForcePathStyle bool `json:"force_path_style"`
+ CDNURL string `json:"cdn_url"`
+ DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
+ UpdatedAt string `json:"updated_at"`
+}
+
+// SoraS3ProfileList Sora S3 多配置列表
+type SoraS3ProfileList struct {
+ ActiveProfileID string `json:"active_profile_id"`
+ Items []SoraS3Profile `json:"items"`
+}
+
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
type StreamTimeoutSettings struct {
// Enabled 是否启用流超时处理
diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go
index 4680538c..0a914d2d 100644
--- a/backend/internal/service/sora_client.go
+++ b/backend/internal/service/sora_client.go
@@ -43,6 +43,7 @@ type SoraVideoRequest struct {
Frames int
Model string
Size string
+ VideoCount int
MediaID string
RemixTargetID string
CameoIDs []string
diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go
index b8241eef..ab6871bb 100644
--- a/backend/internal/service/sora_gateway_service.go
+++ b/backend/internal/service/sora_gateway_service.go
@@ -21,6 +21,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
)
@@ -63,8 +64,8 @@ var soraBlockedCIDRs = mustParseCIDRs([]string{
// SoraGatewayService handles forwarding requests to Sora upstream.
type SoraGatewayService struct {
soraClient SoraClient
- mediaStorage *SoraMediaStorage
rateLimitService *RateLimitService
+ httpUpstream HTTPUpstream // 用于 apikey 类型账号的 HTTP 透传
cfg *config.Config
}
@@ -100,14 +101,14 @@ type soraPreflightChecker interface {
func NewSoraGatewayService(
soraClient SoraClient,
- mediaStorage *SoraMediaStorage,
rateLimitService *RateLimitService,
+ httpUpstream HTTPUpstream,
cfg *config.Config,
) *SoraGatewayService {
return &SoraGatewayService{
soraClient: soraClient,
- mediaStorage: mediaStorage,
rateLimitService: rateLimitService,
+ httpUpstream: httpUpstream,
cfg: cfg,
}
}
@@ -115,6 +116,15 @@ func NewSoraGatewayService(
func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) {
startTime := time.Now()
+ // apikey 类型账号:HTTP 透传到上游,不走 SoraSDKClient
+ if account.Type == AccountTypeAPIKey && account.GetBaseURL() != "" {
+ if s.httpUpstream == nil {
+ s.writeSoraError(c, http.StatusInternalServerError, "api_error", "HTTP upstream client not configured", clientStream)
+ return nil, errors.New("httpUpstream not configured for sora apikey forwarding")
+ }
+ return s.forwardToUpstream(ctx, c, account, body, clientStream, startTime)
+ }
+
if s.soraClient == nil || !s.soraClient.Enabled() {
if c != nil {
c.JSON(http.StatusServiceUnavailable, gin.H{
@@ -296,6 +306,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
taskID := ""
var err error
+ videoCount := parseSoraVideoCount(reqBody)
switch modelCfg.Type {
case "image":
taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{
@@ -321,6 +332,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
Frames: modelCfg.Frames,
Model: modelCfg.Model,
Size: modelCfg.Size,
+ VideoCount: videoCount,
MediaID: mediaID,
RemixTargetID: remixTargetID,
CameoIDs: extractSoraCameoIDs(reqBody),
@@ -378,16 +390,9 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
}
}
+ // 直调路径(/sora/v1/chat/completions)保持纯透传,不执行本地/S3 媒体落盘。
+ // 媒体存储由客户端 API 路径(/api/v1/sora/generate)的异步流程负责。
finalURLs := s.normalizeSoraMediaURLs(mediaURLs)
- if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
- stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
- if storeErr != nil {
- // 存储失败时降级使用原始 URL,不中断用户请求
- log.Printf("[Sora] StoreFromURLs failed, falling back to original URLs: %v", storeErr)
- } else {
- finalURLs = s.normalizeSoraMediaURLs(stored)
- }
- }
if watermarkPostID != "" && watermarkOpts.DeletePost {
if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil {
log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr)
@@ -463,6 +468,20 @@ func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions {
}
}
+func parseSoraVideoCount(body map[string]any) int {
+ if body == nil {
+ return 1
+ }
+ keys := []string{"video_count", "videos", "n_variants"}
+ for _, key := range keys {
+ count := parseIntWithDefault(body, key, 0)
+ if count > 0 {
+ return clampInt(count, 1, 3)
+ }
+ }
+ return 1
+}
+
func parseBoolWithDefault(body map[string]any, key string, def bool) bool {
if body == nil {
return def
@@ -508,6 +527,42 @@ func parseStringWithDefault(body map[string]any, key, def string) string {
return def
}
+func parseIntWithDefault(body map[string]any, key string, def int) int {
+ if body == nil {
+ return def
+ }
+ val, ok := body[key]
+ if !ok {
+ return def
+ }
+ switch typed := val.(type) {
+ case int:
+ return typed
+ case int32:
+ return int(typed)
+ case int64:
+ return int(typed)
+ case float64:
+ return int(typed)
+ case string:
+ parsed, err := strconv.Atoi(strings.TrimSpace(typed))
+ if err == nil {
+ return parsed
+ }
+ }
+ return def
+}
+
+func clampInt(v, minVal, maxVal int) int {
+ if v < minVal {
+ return minVal
+ }
+ if v > maxVal {
+ return maxVal
+ }
+ return v
+}
+
func extractSoraCameoIDs(body map[string]any) []string {
if body == nil {
return nil
@@ -904,6 +959,21 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account
}
var upstreamErr *SoraUpstreamError
if errors.As(err, &upstreamErr) {
+ accountID := int64(0)
+ if account != nil {
+ accountID = account.ID
+ }
+ logger.LegacyPrintf(
+ "service.sora",
+ "[SoraRawError] account_id=%d model=%s status=%d request_id=%s cf_ray=%s message=%s raw_body=%s",
+ accountID,
+ model,
+ upstreamErr.StatusCode,
+ strings.TrimSpace(upstreamErr.Headers.Get("x-request-id")),
+ strings.TrimSpace(upstreamErr.Headers.Get("cf-ray")),
+ strings.TrimSpace(upstreamErr.Message),
+ truncateForLog(upstreamErr.Body, 1024),
+ )
if s.rateLimitService != nil && account != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
}
diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go
index 5888fe92..206636ff 100644
--- a/backend/internal/service/sora_gateway_service_test.go
+++ b/backend/internal/service/sora_gateway_service_test.go
@@ -179,6 +179,31 @@ func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
require.True(t, client.storyboard)
}
+func TestSoraGatewayService_ForwardVideoCount(t *testing.T) {
+ client := &stubSoraClientForPoll{
+ videoStatus: &SoraVideoTaskStatus{
+ Status: "completed",
+ URLs: []string{"https://example.com/v.mp4"},
+ },
+ }
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ PollIntervalSeconds: 1,
+ MaxPollAttempts: 1,
+ },
+ },
+ }
+ svc := NewSoraGatewayService(client, nil, nil, cfg)
+ account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
+ body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"video_count":3,"stream":false}`)
+
+ result, err := svc.Forward(context.Background(), nil, account, body, false)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, 3, client.videoReq.VideoCount)
+}
+
func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) {
client := &stubSoraClientForPoll{}
cfg := &config.Config{
@@ -524,3 +549,10 @@ func TestParseSoraWatermarkOptions_NumericBool(t *testing.T) {
require.True(t, opts.Enabled)
require.False(t, opts.FallbackOnFailure)
}
+
+func TestParseSoraVideoCount(t *testing.T) {
+ require.Equal(t, 1, parseSoraVideoCount(nil))
+ require.Equal(t, 2, parseSoraVideoCount(map[string]any{"video_count": float64(2)}))
+ require.Equal(t, 3, parseSoraVideoCount(map[string]any{"videos": "5"}))
+ require.Equal(t, 1, parseSoraVideoCount(map[string]any{"n_variants": 0}))
+}
diff --git a/backend/internal/service/sora_generation.go b/backend/internal/service/sora_generation.go
new file mode 100644
index 00000000..a704454b
--- /dev/null
+++ b/backend/internal/service/sora_generation.go
@@ -0,0 +1,63 @@
+package service
+
+import (
+ "context"
+ "time"
+)
+
+// SoraGeneration 代表一条 Sora 客户端生成记录。
+type SoraGeneration struct {
+ ID int64 `json:"id"`
+ UserID int64 `json:"user_id"`
+ APIKeyID *int64 `json:"api_key_id,omitempty"`
+ Model string `json:"model"`
+ Prompt string `json:"prompt"`
+ MediaType string `json:"media_type"` // video / image
+ Status string `json:"status"` // pending / generating / completed / failed / cancelled
+ MediaURL string `json:"media_url"` // 主媒体 URL(预签名或 CDN)
+ MediaURLs []string `json:"media_urls"` // 多图时的 URL 数组
+ FileSizeBytes int64 `json:"file_size_bytes"`
+ StorageType string `json:"storage_type"` // s3 / local / upstream / none
+ S3ObjectKeys []string `json:"s3_object_keys"` // S3 object key 数组
+ UpstreamTaskID string `json:"upstream_task_id"`
+ ErrorMessage string `json:"error_message"`
+ CreatedAt time.Time `json:"created_at"`
+ CompletedAt *time.Time `json:"completed_at,omitempty"`
+}
+
+// Sora 生成记录状态常量
+const (
+ SoraGenStatusPending = "pending"
+ SoraGenStatusGenerating = "generating"
+ SoraGenStatusCompleted = "completed"
+ SoraGenStatusFailed = "failed"
+ SoraGenStatusCancelled = "cancelled"
+)
+
+// Sora 存储类型常量
+const (
+ SoraStorageTypeS3 = "s3"
+ SoraStorageTypeLocal = "local"
+ SoraStorageTypeUpstream = "upstream"
+ SoraStorageTypeNone = "none"
+)
+
+// SoraGenerationListParams 查询生成记录的参数。
+type SoraGenerationListParams struct {
+ UserID int64
+ Status string // 可选筛选
+ StorageType string // 可选筛选
+ MediaType string // 可选筛选
+ Page int
+ PageSize int
+}
+
+// SoraGenerationRepository 生成记录持久化接口。
+type SoraGenerationRepository interface {
+ Create(ctx context.Context, gen *SoraGeneration) error
+ GetByID(ctx context.Context, id int64) (*SoraGeneration, error)
+ Update(ctx context.Context, gen *SoraGeneration) error
+ Delete(ctx context.Context, id int64) error
+ List(ctx context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error)
+ CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error)
+}
diff --git a/backend/internal/service/sora_generation_service.go b/backend/internal/service/sora_generation_service.go
new file mode 100644
index 00000000..22d5b519
--- /dev/null
+++ b/backend/internal/service/sora_generation_service.go
@@ -0,0 +1,332 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sync"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+)
+
+var (
+ // ErrSoraGenerationConcurrencyLimit 表示用户进行中的任务数超限。
+ ErrSoraGenerationConcurrencyLimit = errors.New("sora generation concurrent limit exceeded")
+ // ErrSoraGenerationStateConflict 表示状态已发生变化(例如任务已取消)。
+ ErrSoraGenerationStateConflict = errors.New("sora generation state conflict")
+ // ErrSoraGenerationNotActive 表示任务不在可取消状态。
+ ErrSoraGenerationNotActive = errors.New("sora generation is not active")
+)
+
+const soraGenerationActiveLimit = 3
+
+type soraGenerationRepoAtomicCreator interface {
+ CreatePendingWithLimit(ctx context.Context, gen *SoraGeneration, activeStatuses []string, maxActive int64) error
+}
+
+type soraGenerationRepoConditionalUpdater interface {
+ UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error)
+ UpdateCompletedIfActive(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64, completedAt time.Time) (bool, error)
+ UpdateFailedIfActive(ctx context.Context, id int64, errMsg string, completedAt time.Time) (bool, error)
+ UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error)
+ UpdateStorageIfCompleted(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64) (bool, error)
+}
+
+// SoraGenerationService 管理 Sora 客户端的生成记录 CRUD。
+type SoraGenerationService struct {
+ genRepo SoraGenerationRepository
+ s3Storage *SoraS3Storage
+ quotaService *SoraQuotaService
+}
+
+// NewSoraGenerationService 创建生成记录服务。
+func NewSoraGenerationService(
+ genRepo SoraGenerationRepository,
+ s3Storage *SoraS3Storage,
+ quotaService *SoraQuotaService,
+) *SoraGenerationService {
+ return &SoraGenerationService{
+ genRepo: genRepo,
+ s3Storage: s3Storage,
+ quotaService: quotaService,
+ }
+}
+
+// CreatePending 创建一条 pending 状态的生成记录。
+func (s *SoraGenerationService) CreatePending(ctx context.Context, userID int64, apiKeyID *int64, model, prompt, mediaType string) (*SoraGeneration, error) {
+ gen := &SoraGeneration{
+ UserID: userID,
+ APIKeyID: apiKeyID,
+ Model: model,
+ Prompt: prompt,
+ MediaType: mediaType,
+ Status: SoraGenStatusPending,
+ StorageType: SoraStorageTypeNone,
+ }
+ if atomicCreator, ok := s.genRepo.(soraGenerationRepoAtomicCreator); ok {
+ if err := atomicCreator.CreatePendingWithLimit(
+ ctx,
+ gen,
+ []string{SoraGenStatusPending, SoraGenStatusGenerating},
+ soraGenerationActiveLimit,
+ ); err != nil {
+ if errors.Is(err, ErrSoraGenerationConcurrencyLimit) {
+ return nil, err
+ }
+ return nil, fmt.Errorf("create generation: %w", err)
+ }
+ logger.LegacyPrintf("service.sora_gen", "[SoraGen] 创建记录 id=%d user=%d model=%s", gen.ID, userID, model)
+ return gen, nil
+ }
+
+ if err := s.genRepo.Create(ctx, gen); err != nil {
+ return nil, fmt.Errorf("create generation: %w", err)
+ }
+ logger.LegacyPrintf("service.sora_gen", "[SoraGen] 创建记录 id=%d user=%d model=%s", gen.ID, userID, model)
+ return gen, nil
+}
+
+// MarkGenerating 标记为生成中。
+func (s *SoraGenerationService) MarkGenerating(ctx context.Context, id int64, upstreamTaskID string) error {
+ if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
+ updated, err := updater.UpdateGeneratingIfPending(ctx, id, upstreamTaskID)
+ if err != nil {
+ return err
+ }
+ if !updated {
+ return ErrSoraGenerationStateConflict
+ }
+ return nil
+ }
+
+ gen, err := s.genRepo.GetByID(ctx, id)
+ if err != nil {
+ return err
+ }
+ if gen.Status != SoraGenStatusPending {
+ return ErrSoraGenerationStateConflict
+ }
+ gen.Status = SoraGenStatusGenerating
+ gen.UpstreamTaskID = upstreamTaskID
+ return s.genRepo.Update(ctx, gen)
+}
+
+// MarkCompleted 标记为已完成。
+func (s *SoraGenerationService) MarkCompleted(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64) error {
+ now := time.Now()
+ if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
+ updated, err := updater.UpdateCompletedIfActive(ctx, id, mediaURL, mediaURLs, storageType, s3Keys, fileSizeBytes, now)
+ if err != nil {
+ return err
+ }
+ if !updated {
+ return ErrSoraGenerationStateConflict
+ }
+ return nil
+ }
+
+ gen, err := s.genRepo.GetByID(ctx, id)
+ if err != nil {
+ return err
+ }
+ if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating {
+ return ErrSoraGenerationStateConflict
+ }
+ gen.Status = SoraGenStatusCompleted
+ gen.MediaURL = mediaURL
+ gen.MediaURLs = mediaURLs
+ gen.StorageType = storageType
+ gen.S3ObjectKeys = s3Keys
+ gen.FileSizeBytes = fileSizeBytes
+ gen.CompletedAt = &now
+ return s.genRepo.Update(ctx, gen)
+}
+
+// MarkFailed 标记为失败。
+func (s *SoraGenerationService) MarkFailed(ctx context.Context, id int64, errMsg string) error {
+ now := time.Now()
+ if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
+ updated, err := updater.UpdateFailedIfActive(ctx, id, errMsg, now)
+ if err != nil {
+ return err
+ }
+ if !updated {
+ return ErrSoraGenerationStateConflict
+ }
+ return nil
+ }
+
+ gen, err := s.genRepo.GetByID(ctx, id)
+ if err != nil {
+ return err
+ }
+ if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating {
+ return ErrSoraGenerationStateConflict
+ }
+ gen.Status = SoraGenStatusFailed
+ gen.ErrorMessage = errMsg
+ gen.CompletedAt = &now
+ return s.genRepo.Update(ctx, gen)
+}
+
+// MarkCancelled 标记为已取消。
+func (s *SoraGenerationService) MarkCancelled(ctx context.Context, id int64) error {
+ now := time.Now()
+ if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
+ updated, err := updater.UpdateCancelledIfActive(ctx, id, now)
+ if err != nil {
+ return err
+ }
+ if !updated {
+ return ErrSoraGenerationNotActive
+ }
+ return nil
+ }
+
+ gen, err := s.genRepo.GetByID(ctx, id)
+ if err != nil {
+ return err
+ }
+ if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating {
+ return ErrSoraGenerationNotActive
+ }
+ gen.Status = SoraGenStatusCancelled
+ gen.CompletedAt = &now
+ return s.genRepo.Update(ctx, gen)
+}
+
+// UpdateStorageForCompleted 更新已完成记录的存储信息(不重置 completed_at)。
+func (s *SoraGenerationService) UpdateStorageForCompleted(
+ ctx context.Context,
+ id int64,
+ mediaURL string,
+ mediaURLs []string,
+ storageType string,
+ s3Keys []string,
+ fileSizeBytes int64,
+) error {
+ if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok {
+ updated, err := updater.UpdateStorageIfCompleted(ctx, id, mediaURL, mediaURLs, storageType, s3Keys, fileSizeBytes)
+ if err != nil {
+ return err
+ }
+ if !updated {
+ return ErrSoraGenerationStateConflict
+ }
+ return nil
+ }
+
+ gen, err := s.genRepo.GetByID(ctx, id)
+ if err != nil {
+ return err
+ }
+ if gen.Status != SoraGenStatusCompleted {
+ return ErrSoraGenerationStateConflict
+ }
+ gen.MediaURL = mediaURL
+ gen.MediaURLs = mediaURLs
+ gen.StorageType = storageType
+ gen.S3ObjectKeys = s3Keys
+ gen.FileSizeBytes = fileSizeBytes
+ return s.genRepo.Update(ctx, gen)
+}
+
+// GetByID 获取记录详情(含权限校验)。
+func (s *SoraGenerationService) GetByID(ctx context.Context, id, userID int64) (*SoraGeneration, error) {
+ gen, err := s.genRepo.GetByID(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ if gen.UserID != userID {
+ return nil, fmt.Errorf("无权访问此生成记录")
+ }
+ return gen, nil
+}
+
+// List 查询生成记录列表(分页 + 筛选)。
+func (s *SoraGenerationService) List(ctx context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) {
+ if params.Page <= 0 {
+ params.Page = 1
+ }
+ if params.PageSize <= 0 {
+ params.PageSize = 20
+ }
+ if params.PageSize > 100 {
+ params.PageSize = 100
+ }
+ return s.genRepo.List(ctx, params)
+}
+
+// Delete 删除记录(联动 S3/本地文件清理 + 配额释放)。
+func (s *SoraGenerationService) Delete(ctx context.Context, id, userID int64) error {
+ gen, err := s.genRepo.GetByID(ctx, id)
+ if err != nil {
+ return err
+ }
+ if gen.UserID != userID {
+ return fmt.Errorf("无权删除此生成记录")
+ }
+
+ // 清理 S3 文件
+ if gen.StorageType == SoraStorageTypeS3 && len(gen.S3ObjectKeys) > 0 && s.s3Storage != nil {
+ if err := s.s3Storage.DeleteObjects(ctx, gen.S3ObjectKeys); err != nil {
+ logger.LegacyPrintf("service.sora_gen", "[SoraGen] S3 清理失败 id=%d err=%v", id, err)
+ }
+ }
+
+ // 释放配额(S3/本地均释放)
+ if gen.FileSizeBytes > 0 && (gen.StorageType == SoraStorageTypeS3 || gen.StorageType == SoraStorageTypeLocal) && s.quotaService != nil {
+ if err := s.quotaService.ReleaseUsage(ctx, userID, gen.FileSizeBytes); err != nil {
+ logger.LegacyPrintf("service.sora_gen", "[SoraGen] 配额释放失败 id=%d err=%v", id, err)
+ }
+ }
+
+ return s.genRepo.Delete(ctx, id)
+}
+
+// CountActiveByUser 统计用户进行中的任务数(用于并发限制)。
+func (s *SoraGenerationService) CountActiveByUser(ctx context.Context, userID int64) (int64, error) {
+ return s.genRepo.CountByUserAndStatus(ctx, userID, []string{SoraGenStatusPending, SoraGenStatusGenerating})
+}
+
+// ResolveMediaURLs 为 S3 记录动态生成预签名 URL。
+func (s *SoraGenerationService) ResolveMediaURLs(ctx context.Context, gen *SoraGeneration) error {
+ if gen == nil || gen.StorageType != SoraStorageTypeS3 || s.s3Storage == nil {
+ return nil
+ }
+ if len(gen.S3ObjectKeys) == 0 {
+ return nil
+ }
+
+ urls := make([]string, len(gen.S3ObjectKeys))
+ var wg sync.WaitGroup
+ var firstErr error
+ var errMu sync.Mutex
+
+ for idx, key := range gen.S3ObjectKeys {
+ wg.Add(1)
+ go func(i int, objectKey string) {
+ defer wg.Done()
+ url, err := s.s3Storage.GetAccessURL(ctx, objectKey)
+ if err != nil {
+ errMu.Lock()
+ if firstErr == nil {
+ firstErr = err
+ }
+ errMu.Unlock()
+ return
+ }
+ urls[i] = url
+ }(idx, key)
+ }
+ wg.Wait()
+ if firstErr != nil {
+ return firstErr
+ }
+
+ gen.MediaURL = urls[0]
+ gen.MediaURLs = urls
+
+ return nil
+}
diff --git a/backend/internal/service/sora_generation_service_test.go b/backend/internal/service/sora_generation_service_test.go
new file mode 100644
index 00000000..46f322c8
--- /dev/null
+++ b/backend/internal/service/sora_generation_service_test.go
@@ -0,0 +1,878 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+ "github.com/stretchr/testify/require"
+)
+
+// ==================== Stub: SoraGenerationRepository ====================
+
+var _ SoraGenerationRepository = (*stubGenRepo)(nil)
+
+type stubGenRepo struct {
+ gens map[int64]*SoraGeneration
+ nextID int64
+ createErr error
+ getErr error
+ updateErr error
+ deleteErr error
+ listErr error
+ countErr error
+ countValue int64
+}
+
+func newStubGenRepo() *stubGenRepo {
+ return &stubGenRepo{gens: make(map[int64]*SoraGeneration), nextID: 1}
+}
+
+func (r *stubGenRepo) Create(_ context.Context, gen *SoraGeneration) error {
+ if r.createErr != nil {
+ return r.createErr
+ }
+ gen.ID = r.nextID
+ gen.CreatedAt = time.Now()
+ r.nextID++
+ r.gens[gen.ID] = gen
+ return nil
+}
+
+func (r *stubGenRepo) GetByID(_ context.Context, id int64) (*SoraGeneration, error) {
+ if r.getErr != nil {
+ return nil, r.getErr
+ }
+ if gen, ok := r.gens[id]; ok {
+ return gen, nil
+ }
+ return nil, fmt.Errorf("not found")
+}
+
+func (r *stubGenRepo) Update(_ context.Context, gen *SoraGeneration) error {
+ if r.updateErr != nil {
+ return r.updateErr
+ }
+ r.gens[gen.ID] = gen
+ return nil
+}
+
+func (r *stubGenRepo) Delete(_ context.Context, id int64) error {
+ if r.deleteErr != nil {
+ return r.deleteErr
+ }
+ delete(r.gens, id)
+ return nil
+}
+
+func (r *stubGenRepo) List(_ context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) {
+ if r.listErr != nil {
+ return nil, 0, r.listErr
+ }
+ var result []*SoraGeneration
+ for _, gen := range r.gens {
+ if gen.UserID != params.UserID {
+ continue
+ }
+ if params.Status != "" && gen.Status != params.Status {
+ continue
+ }
+ if params.StorageType != "" && gen.StorageType != params.StorageType {
+ continue
+ }
+ if params.MediaType != "" && gen.MediaType != params.MediaType {
+ continue
+ }
+ result = append(result, gen)
+ }
+ return result, int64(len(result)), nil
+}
+
+func (r *stubGenRepo) CountByUserAndStatus(_ context.Context, userID int64, statuses []string) (int64, error) {
+ if r.countErr != nil {
+ return 0, r.countErr
+ }
+ if r.countValue > 0 {
+ return r.countValue, nil
+ }
+ var count int64
+ statusSet := make(map[string]struct{})
+ for _, s := range statuses {
+ statusSet[s] = struct{}{}
+ }
+ for _, gen := range r.gens {
+ if gen.UserID == userID {
+ if _, ok := statusSet[gen.Status]; ok {
+ count++
+ }
+ }
+ }
+ return count, nil
+}
+
+// ==================== Stub: UserRepository (用于 SoraQuotaService) ====================
+
+var _ UserRepository = (*stubUserRepoForQuota)(nil)
+
+type stubUserRepoForQuota struct {
+ users map[int64]*User
+ updateErr error
+}
+
+func newStubUserRepoForQuota() *stubUserRepoForQuota {
+ return &stubUserRepoForQuota{users: make(map[int64]*User)}
+}
+
+func (r *stubUserRepoForQuota) GetByID(_ context.Context, id int64) (*User, error) {
+ if u, ok := r.users[id]; ok {
+ return u, nil
+ }
+ return nil, fmt.Errorf("user not found")
+}
+func (r *stubUserRepoForQuota) Update(_ context.Context, user *User) error {
+ if r.updateErr != nil {
+ return r.updateErr
+ }
+ r.users[user.ID] = user
+ return nil
+}
+func (r *stubUserRepoForQuota) Create(context.Context, *User) error { return nil }
+func (r *stubUserRepoForQuota) GetByEmail(context.Context, string) (*User, error) {
+ return nil, nil
+}
+func (r *stubUserRepoForQuota) GetFirstAdmin(context.Context) (*User, error) { return nil, nil }
+func (r *stubUserRepoForQuota) Delete(context.Context, int64) error { return nil }
+func (r *stubUserRepoForQuota) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (r *stubUserRepoForQuota) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (r *stubUserRepoForQuota) UpdateBalance(context.Context, int64, float64) error { return nil }
+func (r *stubUserRepoForQuota) DeductBalance(context.Context, int64, float64) error { return nil }
+func (r *stubUserRepoForQuota) UpdateConcurrency(context.Context, int64, int) error { return nil }
+func (r *stubUserRepoForQuota) ExistsByEmail(context.Context, string) (bool, error) {
+ return false, nil
+}
+func (r *stubUserRepoForQuota) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+func (r *stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error { return nil }
+func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error { return nil }
+func (r *stubUserRepoForQuota) AddGroupToAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+
+// ==================== 辅助函数:构造带 CDN 缓存的 SoraS3Storage ====================
+
+// newS3StorageWithCDN 创建一个预缓存了 CDN 配置的 SoraS3Storage,
+// 避免实际初始化 AWS 客户端。用于测试 GetAccessURL 的 CDN 路径。
+func newS3StorageWithCDN(cdnURL string) *SoraS3Storage {
+ storage := &SoraS3Storage{}
+ storage.cfg = &SoraS3Settings{
+ Enabled: true,
+ Bucket: "test-bucket",
+ CDNURL: cdnURL,
+ }
+ // 需要 non-nil client 使 getClient 命中缓存
+ storage.client = s3.New(s3.Options{})
+ return storage
+}
+
+// newS3StorageFailingDelete 创建一个 settingService=nil 的 SoraS3Storage,
+// 使 DeleteObjects 返回错误(无法获取配置)。用于测试 Delete 方法 S3 清理失败但仍继续的场景。
+func newS3StorageFailingDelete() *SoraS3Storage {
+ return &SoraS3Storage{} // settingService 为 nil → getConfig 返回 error
+}
+
+// ==================== CreatePending ====================
+
+func TestCreatePending_Success(t *testing.T) {
+ repo := newStubGenRepo()
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "一只猫跳舞", "video")
+ require.NoError(t, err)
+ require.Equal(t, int64(1), gen.ID)
+ require.Equal(t, int64(1), gen.UserID)
+ require.Equal(t, "sora2-landscape-10s", gen.Model)
+ require.Equal(t, "一只猫跳舞", gen.Prompt)
+ require.Equal(t, "video", gen.MediaType)
+ require.Equal(t, SoraGenStatusPending, gen.Status)
+ require.Equal(t, SoraStorageTypeNone, gen.StorageType)
+ require.Nil(t, gen.APIKeyID)
+}
+
+func TestCreatePending_WithAPIKeyID(t *testing.T) {
+ repo := newStubGenRepo()
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ apiKeyID := int64(42)
+ gen, err := svc.CreatePending(context.Background(), 1, &apiKeyID, "gpt-image", "画一朵花", "image")
+ require.NoError(t, err)
+ require.NotNil(t, gen.APIKeyID)
+ require.Equal(t, int64(42), *gen.APIKeyID)
+}
+
+func TestCreatePending_RepoError(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.createErr = fmt.Errorf("db write error")
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
+ require.Error(t, err)
+ require.Nil(t, gen)
+ require.Contains(t, err.Error(), "create generation")
+}
+
+// ==================== MarkGenerating ====================
+
+func TestMarkGenerating_Success(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.MarkGenerating(context.Background(), 1, "upstream-task-123")
+ require.NoError(t, err)
+ require.Equal(t, SoraGenStatusGenerating, repo.gens[1].Status)
+ require.Equal(t, "upstream-task-123", repo.gens[1].UpstreamTaskID)
+}
+
+func TestMarkGenerating_NotFound(t *testing.T) {
+ repo := newStubGenRepo()
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.MarkGenerating(context.Background(), 999, "")
+ require.Error(t, err)
+}
+
+func TestMarkGenerating_UpdateError(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
+ repo.updateErr = fmt.Errorf("update failed")
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.MarkGenerating(context.Background(), 1, "")
+ require.Error(t, err)
+}
+
+// ==================== MarkCompleted ====================
+
+func TestMarkCompleted_Success(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.MarkCompleted(context.Background(), 1,
+ "https://cdn.example.com/video.mp4",
+ []string{"https://cdn.example.com/video.mp4"},
+ SoraStorageTypeS3,
+ []string{"sora/1/2024/01/01/uuid.mp4"},
+ 1048576,
+ )
+ require.NoError(t, err)
+ gen := repo.gens[1]
+ require.Equal(t, SoraGenStatusCompleted, gen.Status)
+ require.Equal(t, "https://cdn.example.com/video.mp4", gen.MediaURL)
+ require.Equal(t, []string{"https://cdn.example.com/video.mp4"}, gen.MediaURLs)
+ require.Equal(t, SoraStorageTypeS3, gen.StorageType)
+ require.Equal(t, []string{"sora/1/2024/01/01/uuid.mp4"}, gen.S3ObjectKeys)
+ require.Equal(t, int64(1048576), gen.FileSizeBytes)
+ require.NotNil(t, gen.CompletedAt)
+}
+
+func TestMarkCompleted_NotFound(t *testing.T) {
+ repo := newStubGenRepo()
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.MarkCompleted(context.Background(), 999, "", nil, "", nil, 0)
+ require.Error(t, err)
+}
+
+func TestMarkCompleted_UpdateError(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
+ repo.updateErr = fmt.Errorf("update failed")
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.MarkCompleted(context.Background(), 1, "url", nil, SoraStorageTypeUpstream, nil, 0)
+ require.Error(t, err)
+}
+
+// ==================== MarkFailed ====================
+
+func TestMarkFailed_Success(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.MarkFailed(context.Background(), 1, "上游返回 500 错误")
+ require.NoError(t, err)
+ gen := repo.gens[1]
+ require.Equal(t, SoraGenStatusFailed, gen.Status)
+ require.Equal(t, "上游返回 500 错误", gen.ErrorMessage)
+ require.NotNil(t, gen.CompletedAt)
+}
+
+func TestMarkFailed_NotFound(t *testing.T) {
+ repo := newStubGenRepo()
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.MarkFailed(context.Background(), 999, "error")
+ require.Error(t, err)
+}
+
+func TestMarkFailed_UpdateError(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
+ repo.updateErr = fmt.Errorf("update failed")
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.MarkFailed(context.Background(), 1, "err")
+ require.Error(t, err)
+}
+
+// ==================== MarkCancelled ====================
+
+func TestMarkCancelled_Pending(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.MarkCancelled(context.Background(), 1)
+ require.NoError(t, err)
+ require.Equal(t, SoraGenStatusCancelled, repo.gens[1].Status)
+ require.NotNil(t, repo.gens[1].CompletedAt)
+}
+
+func TestMarkCancelled_Generating(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating}
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.MarkCancelled(context.Background(), 1)
+ require.NoError(t, err)
+ require.Equal(t, SoraGenStatusCancelled, repo.gens[1].Status)
+}
+
+func TestMarkCancelled_Completed(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted}
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.MarkCancelled(context.Background(), 1)
+ require.Error(t, err)
+ require.ErrorIs(t, err, ErrSoraGenerationNotActive)
+}
+
+func TestMarkCancelled_Failed(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusFailed}
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.MarkCancelled(context.Background(), 1)
+ require.Error(t, err)
+}
+
+func TestMarkCancelled_AlreadyCancelled(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCancelled}
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.MarkCancelled(context.Background(), 1)
+ require.Error(t, err)
+}
+
+func TestMarkCancelled_NotFound(t *testing.T) {
+ repo := newStubGenRepo()
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.MarkCancelled(context.Background(), 999)
+ require.Error(t, err)
+}
+
+func TestMarkCancelled_UpdateError(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
+ repo.updateErr = fmt.Errorf("update failed")
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.MarkCancelled(context.Background(), 1)
+ require.Error(t, err)
+}
+
+// ==================== GetByID ====================
+
+func TestGetByID_Success(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, Model: "sora2-landscape-10s"}
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ gen, err := svc.GetByID(context.Background(), 1, 1)
+ require.NoError(t, err)
+ require.Equal(t, int64(1), gen.ID)
+ require.Equal(t, "sora2-landscape-10s", gen.Model)
+}
+
+func TestGetByID_WrongUser(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 2, Status: SoraGenStatusCompleted}
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ gen, err := svc.GetByID(context.Background(), 1, 1)
+ require.Error(t, err)
+ require.Nil(t, gen)
+ require.Contains(t, err.Error(), "无权访问")
+}
+
+func TestGetByID_NotFound(t *testing.T) {
+ repo := newStubGenRepo()
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ gen, err := svc.GetByID(context.Background(), 999, 1)
+ require.Error(t, err)
+ require.Nil(t, gen)
+}
+
+// ==================== List ====================
+
+func TestList_Success(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, MediaType: "video"}
+ repo.gens[2] = &SoraGeneration{ID: 2, UserID: 1, Status: SoraGenStatusPending, MediaType: "image"}
+ repo.gens[3] = &SoraGeneration{ID: 3, UserID: 2, Status: SoraGenStatusCompleted, MediaType: "video"}
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ gens, total, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1, Page: 1, PageSize: 20})
+ require.NoError(t, err)
+ require.Len(t, gens, 2) // 只有 userID=1 的
+ require.Equal(t, int64(2), total)
+}
+
+func TestList_DefaultPagination(t *testing.T) {
+ repo := newStubGenRepo()
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ // page=0, pageSize=0 → 应修正为 page=1, pageSize=20
+ _, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1})
+ require.NoError(t, err)
+}
+
+func TestList_MaxPageSize(t *testing.T) {
+ repo := newStubGenRepo()
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ // pageSize > 100 → 应限制为 100
+ _, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1, Page: 1, PageSize: 200})
+ require.NoError(t, err)
+}
+
+func TestList_Error(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.listErr = fmt.Errorf("db error")
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ _, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1})
+ require.Error(t, err)
+}
+
+// ==================== Delete ====================
+
+func TestDelete_Success(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, StorageType: SoraStorageTypeUpstream}
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.Delete(context.Background(), 1, 1)
+ require.NoError(t, err)
+ _, exists := repo.gens[1]
+ require.False(t, exists)
+}
+
+func TestDelete_WrongUser(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 2, Status: SoraGenStatusCompleted}
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.Delete(context.Background(), 1, 1)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "无权删除")
+}
+
+func TestDelete_NotFound(t *testing.T) {
+ repo := newStubGenRepo()
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.Delete(context.Background(), 999, 1)
+ require.Error(t, err)
+}
+
+func TestDelete_S3Cleanup_NilS3(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeS3, S3ObjectKeys: []string{"key1"}}
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.Delete(context.Background(), 1, 1)
+ require.NoError(t, err) // s3Storage 为 nil,跳过清理
+}
+
+func TestDelete_QuotaRelease_NilQuota(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeS3, FileSizeBytes: 1024}
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.Delete(context.Background(), 1, 1)
+ require.NoError(t, err) // quotaService 为 nil,跳过释放
+}
+
+func TestDelete_NonS3NoCleanup(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeLocal, FileSizeBytes: 1024}
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.Delete(context.Background(), 1, 1)
+ require.NoError(t, err)
+}
+
+func TestDelete_DeleteRepoError(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeUpstream}
+ repo.deleteErr = fmt.Errorf("delete failed")
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ err := svc.Delete(context.Background(), 1, 1)
+ require.Error(t, err)
+}
+
+// ==================== CountActiveByUser ====================
+
+func TestCountActiveByUser_Success(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending}
+ repo.gens[2] = &SoraGeneration{ID: 2, UserID: 1, Status: SoraGenStatusGenerating}
+ repo.gens[3] = &SoraGeneration{ID: 3, UserID: 1, Status: SoraGenStatusCompleted} // 不算
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ count, err := svc.CountActiveByUser(context.Background(), 1)
+ require.NoError(t, err)
+ require.Equal(t, int64(2), count)
+}
+
+func TestCountActiveByUser_NoActive(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted}
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ count, err := svc.CountActiveByUser(context.Background(), 1)
+ require.NoError(t, err)
+ require.Equal(t, int64(0), count)
+}
+
+func TestCountActiveByUser_Error(t *testing.T) {
+ repo := newStubGenRepo()
+ repo.countErr = fmt.Errorf("db error")
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ _, err := svc.CountActiveByUser(context.Background(), 1)
+ require.Error(t, err)
+}
+
+// ==================== ResolveMediaURLs ====================
+
+func TestResolveMediaURLs_NilGen(t *testing.T) {
+ svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
+ require.NoError(t, svc.ResolveMediaURLs(context.Background(), nil))
+}
+
+func TestResolveMediaURLs_NonS3(t *testing.T) {
+ svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
+ gen := &SoraGeneration{StorageType: SoraStorageTypeUpstream, MediaURL: "https://original.com/v.mp4"}
+ require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen))
+ require.Equal(t, "https://original.com/v.mp4", gen.MediaURL) // 不变
+}
+
+func TestResolveMediaURLs_S3NilStorage(t *testing.T) {
+ svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
+ gen := &SoraGeneration{StorageType: SoraStorageTypeS3, S3ObjectKeys: []string{"key1"}}
+ require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen))
+}
+
+func TestResolveMediaURLs_Local(t *testing.T) {
+ svc := NewSoraGenerationService(newStubGenRepo(), nil, nil)
+ gen := &SoraGeneration{StorageType: SoraStorageTypeLocal, MediaURL: "/video/2024/01/01/file.mp4"}
+ require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen))
+ require.Equal(t, "/video/2024/01/01/file.mp4", gen.MediaURL) // 不变
+}
+
+// ==================== 状态流转完整测试 ====================
+
+func TestStatusTransition_PendingToCompletedFlow(t *testing.T) {
+ repo := newStubGenRepo()
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ // 1. 创建 pending
+ gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
+ require.NoError(t, err)
+ require.Equal(t, SoraGenStatusPending, gen.Status)
+
+ // 2. 标记 generating
+ err = svc.MarkGenerating(context.Background(), gen.ID, "task-123")
+ require.NoError(t, err)
+ require.Equal(t, SoraGenStatusGenerating, repo.gens[gen.ID].Status)
+
+ // 3. 标记 completed
+ err = svc.MarkCompleted(context.Background(), gen.ID, "https://s3.com/video.mp4", nil, SoraStorageTypeS3, []string{"key"}, 1024)
+ require.NoError(t, err)
+ require.Equal(t, SoraGenStatusCompleted, repo.gens[gen.ID].Status)
+}
+
+func TestStatusTransition_PendingToFailedFlow(t *testing.T) {
+ repo := newStubGenRepo()
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
+ _ = svc.MarkGenerating(context.Background(), gen.ID, "")
+
+ err := svc.MarkFailed(context.Background(), gen.ID, "上游超时")
+ require.NoError(t, err)
+ require.Equal(t, SoraGenStatusFailed, repo.gens[gen.ID].Status)
+ require.Equal(t, "上游超时", repo.gens[gen.ID].ErrorMessage)
+}
+
+func TestStatusTransition_PendingToCancelledFlow(t *testing.T) {
+ repo := newStubGenRepo()
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
+ err := svc.MarkCancelled(context.Background(), gen.ID)
+ require.NoError(t, err)
+ require.Equal(t, SoraGenStatusCancelled, repo.gens[gen.ID].Status)
+}
+
+func TestStatusTransition_GeneratingToCancelledFlow(t *testing.T) {
+ repo := newStubGenRepo()
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
+ _ = svc.MarkGenerating(context.Background(), gen.ID, "")
+ err := svc.MarkCancelled(context.Background(), gen.ID)
+ require.NoError(t, err)
+ require.Equal(t, SoraGenStatusCancelled, repo.gens[gen.ID].Status)
+}
+
+// ==================== 权限隔离测试 ====================
+
+func TestUserIsolation_CannotAccessOthersRecord(t *testing.T) {
+ repo := newStubGenRepo()
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
+
+ // 用户 2 尝试访问用户 1 的记录
+ _, err := svc.GetByID(context.Background(), gen.ID, 2)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "无权访问")
+}
+
+func TestUserIsolation_CannotDeleteOthersRecord(t *testing.T) {
+ repo := newStubGenRepo()
+ svc := NewSoraGenerationService(repo, nil, nil)
+
+ gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video")
+
+ err := svc.Delete(context.Background(), gen.ID, 2)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "无权删除")
+}
+
+// ==================== Delete: S3 清理 + 配额释放路径 ====================
+
+func TestDelete_S3Cleanup_WithS3Storage(t *testing.T) {
+ // S3 存储存在但 deleteObjects 会失败(settingService=nil),
+ // 验证 Delete 仍然成功(S3 错误只是记录日志)
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{
+ ID: 1, UserID: 1,
+ StorageType: SoraStorageTypeS3,
+ S3ObjectKeys: []string{"sora/1/2024/01/01/abc.mp4"},
+ }
+ s3Storage := newS3StorageFailingDelete()
+ svc := NewSoraGenerationService(repo, s3Storage, nil)
+
+ err := svc.Delete(context.Background(), 1, 1)
+ require.NoError(t, err) // S3 清理失败不影响删除
+ _, exists := repo.gens[1]
+ require.False(t, exists)
+}
+
+func TestDelete_QuotaRelease_WithQuotaService(t *testing.T) {
+ // 有配额服务时,删除 S3 类型记录会释放配额
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{
+ ID: 1, UserID: 1,
+ StorageType: SoraStorageTypeS3,
+ FileSizeBytes: 1048576, // 1MB
+ }
+
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 2097152} // 2MB
+ quotaService := NewSoraQuotaService(userRepo, nil, nil)
+
+ svc := NewSoraGenerationService(repo, nil, quotaService)
+ err := svc.Delete(context.Background(), 1, 1)
+ require.NoError(t, err)
+ // 配额应被释放: 2MB - 1MB = 1MB
+ require.Equal(t, int64(1048576), userRepo.users[1].SoraStorageUsedBytes)
+}
+
+func TestDelete_S3Cleanup_And_QuotaRelease(t *testing.T) {
+ // S3 清理 + 配额释放同时触发
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{
+ ID: 1, UserID: 1,
+ StorageType: SoraStorageTypeS3,
+ S3ObjectKeys: []string{"key1"},
+ FileSizeBytes: 512,
+ }
+
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
+ quotaService := NewSoraQuotaService(userRepo, nil, nil)
+ s3Storage := newS3StorageFailingDelete()
+
+ svc := NewSoraGenerationService(repo, s3Storage, quotaService)
+ err := svc.Delete(context.Background(), 1, 1)
+ require.NoError(t, err)
+ _, exists := repo.gens[1]
+ require.False(t, exists)
+ require.Equal(t, int64(512), userRepo.users[1].SoraStorageUsedBytes)
+}
+
+func TestDelete_QuotaRelease_LocalStorage(t *testing.T) {
+ // 本地存储同样需要释放配额
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{
+ ID: 1, UserID: 1,
+ StorageType: SoraStorageTypeLocal,
+ FileSizeBytes: 1024,
+ }
+
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 2048}
+ quotaService := NewSoraQuotaService(userRepo, nil, nil)
+
+ svc := NewSoraGenerationService(repo, nil, quotaService)
+ err := svc.Delete(context.Background(), 1, 1)
+ require.NoError(t, err)
+ require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes)
+}
+
+func TestDelete_QuotaRelease_ZeroFileSize(t *testing.T) {
+ // FileSizeBytes=0 跳过配额释放
+ repo := newStubGenRepo()
+ repo.gens[1] = &SoraGeneration{
+ ID: 1, UserID: 1,
+ StorageType: SoraStorageTypeS3,
+ FileSizeBytes: 0,
+ }
+
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
+ quotaService := NewSoraQuotaService(userRepo, nil, nil)
+
+ svc := NewSoraGenerationService(repo, nil, quotaService)
+ err := svc.Delete(context.Background(), 1, 1)
+ require.NoError(t, err)
+ require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes)
+}
+
+// ==================== ResolveMediaURLs: S3 + CDN 路径 ====================
+
+func TestResolveMediaURLs_S3_CDN_SingleKey(t *testing.T) {
+ s3Storage := newS3StorageWithCDN("https://cdn.example.com")
+ svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
+
+ gen := &SoraGeneration{
+ StorageType: SoraStorageTypeS3,
+ S3ObjectKeys: []string{"sora/1/2024/01/01/video.mp4"},
+ MediaURL: "original",
+ }
+ err := svc.ResolveMediaURLs(context.Background(), gen)
+ require.NoError(t, err)
+ require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/video.mp4", gen.MediaURL)
+}
+
+func TestResolveMediaURLs_S3_CDN_MultipleKeys(t *testing.T) {
+ s3Storage := newS3StorageWithCDN("https://cdn.example.com/")
+ svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
+
+ gen := &SoraGeneration{
+ StorageType: SoraStorageTypeS3,
+ S3ObjectKeys: []string{
+ "sora/1/2024/01/01/img1.png",
+ "sora/1/2024/01/01/img2.png",
+ "sora/1/2024/01/01/img3.png",
+ },
+ MediaURL: "original",
+ }
+ err := svc.ResolveMediaURLs(context.Background(), gen)
+ require.NoError(t, err)
+ // 主 URL 更新为第一个 key 的 CDN URL
+ require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img1.png", gen.MediaURL)
+ // 多图 URLs 全部更新
+ require.Len(t, gen.MediaURLs, 3)
+ require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img1.png", gen.MediaURLs[0])
+ require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img2.png", gen.MediaURLs[1])
+ require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img3.png", gen.MediaURLs[2])
+}
+
+func TestResolveMediaURLs_S3_EmptyKeys(t *testing.T) {
+ s3Storage := newS3StorageWithCDN("https://cdn.example.com")
+ svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
+
+ gen := &SoraGeneration{
+ StorageType: SoraStorageTypeS3,
+ S3ObjectKeys: []string{},
+ MediaURL: "original",
+ }
+ err := svc.ResolveMediaURLs(context.Background(), gen)
+ require.NoError(t, err)
+ require.Equal(t, "original", gen.MediaURL) // 不变
+}
+
+func TestResolveMediaURLs_S3_GetAccessURL_Error(t *testing.T) {
+ // 使用无 settingService 的 S3 Storage,getClient 会失败
+ s3Storage := newS3StorageFailingDelete() // 同样 GetAccessURL 也会失败
+ svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
+
+ gen := &SoraGeneration{
+ StorageType: SoraStorageTypeS3,
+ S3ObjectKeys: []string{"sora/1/2024/01/01/video.mp4"},
+ MediaURL: "original",
+ }
+ err := svc.ResolveMediaURLs(context.Background(), gen)
+ require.Error(t, err) // GetAccessURL 失败应传播错误
+}
+
+func TestResolveMediaURLs_S3_MultiKey_ErrorOnSecond(t *testing.T) {
+ // 只有一个 key 时走主 URL 路径成功,但多 key 路径的错误也需覆盖
+ s3Storage := newS3StorageFailingDelete()
+ svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil)
+
+ gen := &SoraGeneration{
+ StorageType: SoraStorageTypeS3,
+ S3ObjectKeys: []string{
+ "sora/1/2024/01/01/img1.png",
+ "sora/1/2024/01/01/img2.png",
+ },
+ MediaURL: "original",
+ }
+ err := svc.ResolveMediaURLs(context.Background(), gen)
+ require.Error(t, err) // 第一个 key 的 GetAccessURL 就会失败
+}
diff --git a/backend/internal/service/sora_media_storage.go b/backend/internal/service/sora_media_storage.go
index eb363c4f..18783865 100644
--- a/backend/internal/service/sora_media_storage.go
+++ b/backend/internal/service/sora_media_storage.go
@@ -157,6 +157,64 @@ func (s *SoraMediaStorage) StoreFromURLs(ctx context.Context, mediaType string,
return results, nil
}
+// TotalSizeByRelativePaths 统计本地存储路径总大小(仅统计 /image 和 /video 路径)。
+func (s *SoraMediaStorage) TotalSizeByRelativePaths(paths []string) (int64, error) {
+ if s == nil || len(paths) == 0 {
+ return 0, nil
+ }
+ var total int64
+ for _, p := range paths {
+ localPath, err := s.resolveLocalPath(p)
+ if err != nil {
+ continue
+ }
+ info, err := os.Stat(localPath)
+ if err != nil {
+ if os.IsNotExist(err) {
+ continue
+ }
+ return 0, err
+ }
+ if info.Mode().IsRegular() {
+ total += info.Size()
+ }
+ }
+ return total, nil
+}
+
+// DeleteByRelativePaths 删除本地媒体路径(仅删除 /image 和 /video 路径)。
+func (s *SoraMediaStorage) DeleteByRelativePaths(paths []string) error {
+ if s == nil || len(paths) == 0 {
+ return nil
+ }
+ var lastErr error
+ for _, p := range paths {
+ localPath, err := s.resolveLocalPath(p)
+ if err != nil {
+ continue
+ }
+ if err := os.Remove(localPath); err != nil && !os.IsNotExist(err) {
+ lastErr = err
+ }
+ }
+ return lastErr
+}
+
+func (s *SoraMediaStorage) resolveLocalPath(relativePath string) (string, error) {
+ if s == nil || strings.TrimSpace(relativePath) == "" {
+ return "", errors.New("empty path")
+ }
+ cleaned := path.Clean(relativePath)
+ if !strings.HasPrefix(cleaned, "/image/") && !strings.HasPrefix(cleaned, "/video/") {
+ return "", errors.New("not a local media path")
+ }
+ if strings.TrimSpace(s.root) == "" {
+ return "", errors.New("storage root not configured")
+ }
+ relative := strings.TrimPrefix(cleaned, "/")
+ return filepath.Join(s.root, filepath.FromSlash(relative)), nil
+}
+
func (s *SoraMediaStorage) downloadAndStore(ctx context.Context, mediaType, rawURL string) (string, error) {
if strings.TrimSpace(rawURL) == "" {
return "", errors.New("empty url")
diff --git a/backend/internal/service/sora_models.go b/backend/internal/service/sora_models.go
index 80b20a4b..53d4c788 100644
--- a/backend/internal/service/sora_models.go
+++ b/backend/internal/service/sora_models.go
@@ -1,6 +1,9 @@
package service
import (
+ "regexp"
+ "sort"
+ "strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -247,6 +250,218 @@ func GetSoraModelConfig(model string) (SoraModelConfig, bool) {
return cfg, ok
}
+// SoraModelFamily 模型家族(前端 Sora 客户端使用)
+type SoraModelFamily struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ Type string `json:"type"`
+ Orientations []string `json:"orientations"`
+ Durations []int `json:"durations,omitempty"`
+}
+
+var (
+ videoSuffixRe = regexp.MustCompile(`-(landscape|portrait)-(\d+)s$`)
+ imageSuffixRe = regexp.MustCompile(`-(landscape|portrait)$`)
+
+ soraFamilyNames = map[string]string{
+ "sora2": "Sora 2",
+ "sora2pro": "Sora 2 Pro",
+ "sora2pro-hd": "Sora 2 Pro HD",
+ "gpt-image": "GPT Image",
+ }
+)
+
+// BuildSoraModelFamilies 从 soraModelConfigs 自动聚合模型家族及其支持的方向和时长
+func BuildSoraModelFamilies() []SoraModelFamily {
+ type familyData struct {
+ modelType string
+ orientations map[string]bool
+ durations map[int]bool
+ }
+ families := make(map[string]*familyData)
+
+ for id, cfg := range soraModelConfigs {
+ if cfg.Type == "prompt_enhance" {
+ continue
+ }
+ var famID, orientation string
+ var duration int
+
+ switch cfg.Type {
+ case "video":
+ if m := videoSuffixRe.FindStringSubmatch(id); m != nil {
+ famID = id[:len(id)-len(m[0])]
+ orientation = m[1]
+ duration, _ = strconv.Atoi(m[2])
+ }
+ case "image":
+ if m := imageSuffixRe.FindStringSubmatch(id); m != nil {
+ famID = id[:len(id)-len(m[0])]
+ orientation = m[1]
+ } else {
+ famID = id
+ orientation = "square"
+ }
+ }
+ if famID == "" {
+ continue
+ }
+
+ fd, ok := families[famID]
+ if !ok {
+ fd = &familyData{
+ modelType: cfg.Type,
+ orientations: make(map[string]bool),
+ durations: make(map[int]bool),
+ }
+ families[famID] = fd
+ }
+ if orientation != "" {
+ fd.orientations[orientation] = true
+ }
+ if duration > 0 {
+ fd.durations[duration] = true
+ }
+ }
+
+ // 排序:视频在前、图像在后,同类按名称排序
+ famIDs := make([]string, 0, len(families))
+ for id := range families {
+ famIDs = append(famIDs, id)
+ }
+ sort.Slice(famIDs, func(i, j int) bool {
+ fi, fj := families[famIDs[i]], families[famIDs[j]]
+ if fi.modelType != fj.modelType {
+ return fi.modelType == "video"
+ }
+ return famIDs[i] < famIDs[j]
+ })
+
+ result := make([]SoraModelFamily, 0, len(famIDs))
+ for _, famID := range famIDs {
+ fd := families[famID]
+ fam := SoraModelFamily{
+ ID: famID,
+ Name: soraFamilyNames[famID],
+ Type: fd.modelType,
+ }
+ if fam.Name == "" {
+ fam.Name = famID
+ }
+ for o := range fd.orientations {
+ fam.Orientations = append(fam.Orientations, o)
+ }
+ sort.Strings(fam.Orientations)
+ for d := range fd.durations {
+ fam.Durations = append(fam.Durations, d)
+ }
+ sort.Ints(fam.Durations)
+ result = append(result, fam)
+ }
+ return result
+}
+
+// BuildSoraModelFamiliesFromIDs 从任意模型 ID 列表聚合模型家族(用于解析上游返回的模型列表)。
+// 通过命名约定自动识别视频/图像模型并分组。
+func BuildSoraModelFamiliesFromIDs(modelIDs []string) []SoraModelFamily {
+ type familyData struct {
+ modelType string
+ orientations map[string]bool
+ durations map[int]bool
+ }
+ families := make(map[string]*familyData)
+
+ for _, id := range modelIDs {
+ id = strings.ToLower(strings.TrimSpace(id))
+ if id == "" || strings.HasPrefix(id, "prompt-enhance") {
+ continue
+ }
+
+ var famID, orientation, modelType string
+ var duration int
+
+ if m := videoSuffixRe.FindStringSubmatch(id); m != nil {
+ // 视频模型: {family}-{orientation}-{duration}s
+ famID = id[:len(id)-len(m[0])]
+ orientation = m[1]
+ duration, _ = strconv.Atoi(m[2])
+ modelType = "video"
+ } else if m := imageSuffixRe.FindStringSubmatch(id); m != nil {
+ // 图像模型(带方向): {family}-{orientation}
+ famID = id[:len(id)-len(m[0])]
+ orientation = m[1]
+ modelType = "image"
+ } else if cfg, ok := soraModelConfigs[id]; ok && cfg.Type == "image" {
+ // 已知的无后缀图像模型(如 gpt-image)
+ famID = id
+ orientation = "square"
+ modelType = "image"
+ } else if strings.Contains(id, "image") {
+ // 未知但名称包含 image 的模型,推断为图像模型
+ famID = id
+ orientation = "square"
+ modelType = "image"
+ } else {
+ continue
+ }
+
+ if famID == "" {
+ continue
+ }
+
+ fd, ok := families[famID]
+ if !ok {
+ fd = &familyData{
+ modelType: modelType,
+ orientations: make(map[string]bool),
+ durations: make(map[int]bool),
+ }
+ families[famID] = fd
+ }
+ if orientation != "" {
+ fd.orientations[orientation] = true
+ }
+ if duration > 0 {
+ fd.durations[duration] = true
+ }
+ }
+
+ famIDs := make([]string, 0, len(families))
+ for id := range families {
+ famIDs = append(famIDs, id)
+ }
+ sort.Slice(famIDs, func(i, j int) bool {
+ fi, fj := families[famIDs[i]], families[famIDs[j]]
+ if fi.modelType != fj.modelType {
+ return fi.modelType == "video"
+ }
+ return famIDs[i] < famIDs[j]
+ })
+
+ result := make([]SoraModelFamily, 0, len(famIDs))
+ for _, famID := range famIDs {
+ fd := families[famID]
+ fam := SoraModelFamily{
+ ID: famID,
+ Name: soraFamilyNames[famID],
+ Type: fd.modelType,
+ }
+ if fam.Name == "" {
+ fam.Name = famID
+ }
+ for o := range fd.orientations {
+ fam.Orientations = append(fam.Orientations, o)
+ }
+ sort.Strings(fam.Orientations)
+ for d := range fd.durations {
+ fam.Durations = append(fam.Durations, d)
+ }
+ sort.Ints(fam.Durations)
+ result = append(result, fam)
+ }
+ return result
+}
+
// DefaultSoraModels returns the default Sora model list.
func DefaultSoraModels(cfg *config.Config) []openai.Model {
models := make([]openai.Model, 0, len(soraModelIDs))
diff --git a/backend/internal/service/sora_quota_service.go b/backend/internal/service/sora_quota_service.go
new file mode 100644
index 00000000..f0843374
--- /dev/null
+++ b/backend/internal/service/sora_quota_service.go
@@ -0,0 +1,257 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strconv"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+)
+
+// SoraQuotaService 管理 Sora 用户存储配额。
+// 配额优先级:用户级 → 分组级 → 系统默认值。
+type SoraQuotaService struct {
+ userRepo UserRepository
+ groupRepo GroupRepository
+ settingService *SettingService
+}
+
+// NewSoraQuotaService 创建配额服务实例。
+func NewSoraQuotaService(
+ userRepo UserRepository,
+ groupRepo GroupRepository,
+ settingService *SettingService,
+) *SoraQuotaService {
+ return &SoraQuotaService{
+ userRepo: userRepo,
+ groupRepo: groupRepo,
+ settingService: settingService,
+ }
+}
+
+// QuotaInfo 返回给客户端的配额信息。
+type QuotaInfo struct {
+ QuotaBytes int64 `json:"quota_bytes"` // 总配额(0 表示无限制)
+ UsedBytes int64 `json:"used_bytes"` // 已使用
+ AvailableBytes int64 `json:"available_bytes"` // 剩余可用(无限制时为 0)
+ QuotaSource string `json:"quota_source"` // 配额来源:user / group / system / unlimited
+ Source string `json:"source,omitempty"` // 兼容旧字段
+}
+
+// ErrSoraStorageQuotaExceeded 表示配额不足。
+var ErrSoraStorageQuotaExceeded = errors.New("sora storage quota exceeded")
+
+// QuotaExceededError 包含配额不足的上下文信息。
+type QuotaExceededError struct {
+ QuotaBytes int64
+ UsedBytes int64
+}
+
+func (e *QuotaExceededError) Error() string {
+ if e == nil {
+ return "存储配额不足"
+ }
+ return fmt.Sprintf("存储配额不足(已用 %d / 配额 %d 字节)", e.UsedBytes, e.QuotaBytes)
+}
+
+type soraQuotaAtomicUserRepository interface {
+ AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error)
+ ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error)
+}
+
+// GetQuota 获取用户的存储配额信息。
+// 优先级:用户级 > 用户所属分组级 > 系统默认值。
+func (s *SoraQuotaService) GetQuota(ctx context.Context, userID int64) (*QuotaInfo, error) {
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ return nil, fmt.Errorf("get user: %w", err)
+ }
+
+ info := &QuotaInfo{
+ UsedBytes: user.SoraStorageUsedBytes,
+ }
+
+ // 1. 用户级配额
+ if user.SoraStorageQuotaBytes > 0 {
+ info.QuotaBytes = user.SoraStorageQuotaBytes
+ info.QuotaSource = "user"
+ info.Source = info.QuotaSource
+ info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes)
+ return info, nil
+ }
+
+ // 2. 分组级配额(取用户可用分组中最大的配额)
+ if len(user.AllowedGroups) > 0 {
+ var maxGroupQuota int64
+ for _, gid := range user.AllowedGroups {
+ group, err := s.groupRepo.GetByID(ctx, gid)
+ if err != nil {
+ continue
+ }
+ if group.SoraStorageQuotaBytes > maxGroupQuota {
+ maxGroupQuota = group.SoraStorageQuotaBytes
+ }
+ }
+ if maxGroupQuota > 0 {
+ info.QuotaBytes = maxGroupQuota
+ info.QuotaSource = "group"
+ info.Source = info.QuotaSource
+ info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes)
+ return info, nil
+ }
+ }
+
+ // 3. 系统默认值
+ defaultQuota := s.getSystemDefaultQuota(ctx)
+ if defaultQuota > 0 {
+ info.QuotaBytes = defaultQuota
+ info.QuotaSource = "system"
+ info.Source = info.QuotaSource
+ info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes)
+ return info, nil
+ }
+
+ // 无配额限制
+ info.QuotaSource = "unlimited"
+ info.Source = info.QuotaSource
+ info.AvailableBytes = 0
+ return info, nil
+}
+
+// CheckQuota 检查用户是否有足够的存储配额。
+// 返回 nil 表示配额充足或无限制。
+func (s *SoraQuotaService) CheckQuota(ctx context.Context, userID int64, additionalBytes int64) error {
+ quota, err := s.GetQuota(ctx, userID)
+ if err != nil {
+ return err
+ }
+ // 0 表示无限制
+ if quota.QuotaBytes == 0 {
+ return nil
+ }
+ if quota.UsedBytes+additionalBytes > quota.QuotaBytes {
+ return &QuotaExceededError{
+ QuotaBytes: quota.QuotaBytes,
+ UsedBytes: quota.UsedBytes,
+ }
+ }
+ return nil
+}
+
+// AddUsage 原子累加用量(上传成功后调用)。
+func (s *SoraQuotaService) AddUsage(ctx context.Context, userID int64, bytes int64) error {
+ if bytes <= 0 {
+ return nil
+ }
+
+ quota, err := s.GetQuota(ctx, userID)
+ if err != nil {
+ return err
+ }
+
+ if quota.QuotaBytes > 0 && quota.UsedBytes+bytes > quota.QuotaBytes {
+ return &QuotaExceededError{
+ QuotaBytes: quota.QuotaBytes,
+ UsedBytes: quota.UsedBytes,
+ }
+ }
+
+ if repo, ok := s.userRepo.(soraQuotaAtomicUserRepository); ok {
+ newUsed, err := repo.AddSoraStorageUsageWithQuota(ctx, userID, bytes, quota.QuotaBytes)
+ if err != nil {
+ if errors.Is(err, ErrSoraStorageQuotaExceeded) {
+ return &QuotaExceededError{
+ QuotaBytes: quota.QuotaBytes,
+ UsedBytes: quota.UsedBytes,
+ }
+ }
+ return fmt.Errorf("update user quota usage (atomic): %w", err)
+ }
+ logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 累加用量 user=%d +%d total=%d", userID, bytes, newUsed)
+ return nil
+ }
+
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ return fmt.Errorf("get user for quota update: %w", err)
+ }
+ user.SoraStorageUsedBytes += bytes
+ if err := s.userRepo.Update(ctx, user); err != nil {
+ return fmt.Errorf("update user quota usage: %w", err)
+ }
+ logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 累加用量 user=%d +%d total=%d", userID, bytes, user.SoraStorageUsedBytes)
+ return nil
+}
+
+// ReleaseUsage 释放用量(删除文件后调用)。
+func (s *SoraQuotaService) ReleaseUsage(ctx context.Context, userID int64, bytes int64) error {
+ if bytes <= 0 {
+ return nil
+ }
+
+ if repo, ok := s.userRepo.(soraQuotaAtomicUserRepository); ok {
+ newUsed, err := repo.ReleaseSoraStorageUsageAtomic(ctx, userID, bytes)
+ if err != nil {
+ return fmt.Errorf("update user quota release (atomic): %w", err)
+ }
+ logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 释放用量 user=%d -%d total=%d", userID, bytes, newUsed)
+ return nil
+ }
+
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ return fmt.Errorf("get user for quota release: %w", err)
+ }
+ user.SoraStorageUsedBytes -= bytes
+ if user.SoraStorageUsedBytes < 0 {
+ user.SoraStorageUsedBytes = 0
+ }
+ if err := s.userRepo.Update(ctx, user); err != nil {
+ return fmt.Errorf("update user quota release: %w", err)
+ }
+ logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 释放用量 user=%d -%d total=%d", userID, bytes, user.SoraStorageUsedBytes)
+ return nil
+}
+
+func calcAvailableBytes(quotaBytes, usedBytes int64) int64 {
+ if quotaBytes <= 0 {
+ return 0
+ }
+ if usedBytes >= quotaBytes {
+ return 0
+ }
+ return quotaBytes - usedBytes
+}
+
+func (s *SoraQuotaService) getSystemDefaultQuota(ctx context.Context) int64 {
+ if s.settingService == nil {
+ return 0
+ }
+ settings, err := s.settingService.GetSoraS3Settings(ctx)
+ if err != nil {
+ return 0
+ }
+ return settings.DefaultStorageQuotaBytes
+}
+
+// GetQuotaFromSettings 从系统设置获取默认配额(供外部使用)。
+func (s *SoraQuotaService) GetQuotaFromSettings(ctx context.Context) int64 {
+ return s.getSystemDefaultQuota(ctx)
+}
+
+// SetUserQuota 设置用户级配额(管理员操作)。
+func SetUserSoraQuota(ctx context.Context, userRepo UserRepository, userID int64, quotaBytes int64) error {
+ user, err := userRepo.GetByID(ctx, userID)
+ if err != nil {
+ return err
+ }
+ user.SoraStorageQuotaBytes = quotaBytes
+ return userRepo.Update(ctx, user)
+}
+
+// ParseQuotaBytes 解析配额字符串为字节数。
+func ParseQuotaBytes(s string) int64 {
+ v, _ := strconv.ParseInt(s, 10, 64)
+ return v
+}
diff --git a/backend/internal/service/sora_quota_service_test.go b/backend/internal/service/sora_quota_service_test.go
new file mode 100644
index 00000000..040e427d
--- /dev/null
+++ b/backend/internal/service/sora_quota_service_test.go
@@ -0,0 +1,492 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "fmt"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+)
+
+// ==================== Stub: GroupRepository (用于 SoraQuotaService) ====================
+
+var _ GroupRepository = (*stubGroupRepoForQuota)(nil)
+
+type stubGroupRepoForQuota struct {
+ groups map[int64]*Group
+}
+
+func newStubGroupRepoForQuota() *stubGroupRepoForQuota {
+ return &stubGroupRepoForQuota{groups: make(map[int64]*Group)}
+}
+
+func (r *stubGroupRepoForQuota) GetByID(_ context.Context, id int64) (*Group, error) {
+ if g, ok := r.groups[id]; ok {
+ return g, nil
+ }
+ return nil, fmt.Errorf("group not found")
+}
+func (r *stubGroupRepoForQuota) Create(context.Context, *Group) error { return nil }
+func (r *stubGroupRepoForQuota) GetByIDLite(_ context.Context, id int64) (*Group, error) {
+ return r.GetByID(context.Background(), id)
+}
+func (r *stubGroupRepoForQuota) Update(context.Context, *Group) error { return nil }
+func (r *stubGroupRepoForQuota) Delete(context.Context, int64) error { return nil }
+func (r *stubGroupRepoForQuota) DeleteCascade(context.Context, int64) ([]int64, error) {
+ return nil, nil
+}
+func (r *stubGroupRepoForQuota) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (r *stubGroupRepoForQuota) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (r *stubGroupRepoForQuota) ListActive(context.Context) ([]Group, error) { return nil, nil }
+func (r *stubGroupRepoForQuota) ListActiveByPlatform(context.Context, string) ([]Group, error) {
+ return nil, nil
+}
+func (r *stubGroupRepoForQuota) ExistsByName(context.Context, string) (bool, error) {
+ return false, nil
+}
+func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+func (r *stubGroupRepoForQuota) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+func (r *stubGroupRepoForQuota) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
+ return nil, nil
+}
+func (r *stubGroupRepoForQuota) BindAccountsToGroup(context.Context, int64, []int64) error {
+ return nil
+}
+func (r *stubGroupRepoForQuota) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error {
+ return nil
+}
+
+// ==================== Stub: SettingRepository (用于 SettingService) ====================
+
+var _ SettingRepository = (*stubSettingRepoForQuota)(nil)
+
+type stubSettingRepoForQuota struct {
+ values map[string]string
+}
+
+func newStubSettingRepoForQuota(values map[string]string) *stubSettingRepoForQuota {
+ if values == nil {
+ values = make(map[string]string)
+ }
+ return &stubSettingRepoForQuota{values: values}
+}
+
+func (r *stubSettingRepoForQuota) Get(_ context.Context, key string) (*Setting, error) {
+ if v, ok := r.values[key]; ok {
+ return &Setting{Key: key, Value: v}, nil
+ }
+ return nil, ErrSettingNotFound
+}
+func (r *stubSettingRepoForQuota) GetValue(_ context.Context, key string) (string, error) {
+ if v, ok := r.values[key]; ok {
+ return v, nil
+ }
+ return "", ErrSettingNotFound
+}
+func (r *stubSettingRepoForQuota) Set(_ context.Context, key, value string) error {
+ r.values[key] = value
+ return nil
+}
+func (r *stubSettingRepoForQuota) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ result := make(map[string]string)
+ for _, k := range keys {
+ if v, ok := r.values[k]; ok {
+ result[k] = v
+ }
+ }
+ return result, nil
+}
+func (r *stubSettingRepoForQuota) SetMultiple(_ context.Context, settings map[string]string) error {
+ for k, v := range settings {
+ r.values[k] = v
+ }
+ return nil
+}
+func (r *stubSettingRepoForQuota) GetAll(_ context.Context) (map[string]string, error) {
+ return r.values, nil
+}
+func (r *stubSettingRepoForQuota) Delete(_ context.Context, key string) error {
+ delete(r.values, key)
+ return nil
+}
+
+// ==================== GetQuota ====================
+
+func TestGetQuota_UserLevel(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{
+ ID: 1,
+ SoraStorageQuotaBytes: 10 * 1024 * 1024, // 10MB
+ SoraStorageUsedBytes: 3 * 1024 * 1024, // 3MB
+ }
+ svc := NewSoraQuotaService(userRepo, nil, nil)
+
+ quota, err := svc.GetQuota(context.Background(), 1)
+ require.NoError(t, err)
+ require.Equal(t, int64(10*1024*1024), quota.QuotaBytes)
+ require.Equal(t, int64(3*1024*1024), quota.UsedBytes)
+ require.Equal(t, "user", quota.Source)
+}
+
+func TestGetQuota_GroupLevel(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{
+ ID: 1,
+ SoraStorageQuotaBytes: 0, // 用户级无配额
+ SoraStorageUsedBytes: 1024,
+ AllowedGroups: []int64{10, 20},
+ }
+
+ groupRepo := newStubGroupRepoForQuota()
+ groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 5 * 1024 * 1024}
+ groupRepo.groups[20] = &Group{ID: 20, SoraStorageQuotaBytes: 20 * 1024 * 1024}
+
+ svc := NewSoraQuotaService(userRepo, groupRepo, nil)
+ quota, err := svc.GetQuota(context.Background(), 1)
+ require.NoError(t, err)
+ require.Equal(t, int64(20*1024*1024), quota.QuotaBytes) // 取最大值
+ require.Equal(t, "group", quota.Source)
+}
+
+func TestGetQuota_SystemLevel(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{ID: 1, SoraStorageQuotaBytes: 0, SoraStorageUsedBytes: 512}
+
+ settingRepo := newStubSettingRepoForQuota(map[string]string{
+ SettingKeySoraDefaultStorageQuotaBytes: "104857600", // 100MB
+ })
+ settingService := NewSettingService(settingRepo, &config.Config{})
+ svc := NewSoraQuotaService(userRepo, nil, settingService)
+
+ quota, err := svc.GetQuota(context.Background(), 1)
+ require.NoError(t, err)
+ require.Equal(t, int64(104857600), quota.QuotaBytes)
+ require.Equal(t, "system", quota.Source)
+}
+
+func TestGetQuota_NoLimit(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{ID: 1, SoraStorageQuotaBytes: 0, SoraStorageUsedBytes: 0}
+ svc := NewSoraQuotaService(userRepo, nil, nil)
+
+ quota, err := svc.GetQuota(context.Background(), 1)
+ require.NoError(t, err)
+ require.Equal(t, int64(0), quota.QuotaBytes)
+ require.Equal(t, "unlimited", quota.Source)
+}
+
+func TestGetQuota_UserNotFound(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ svc := NewSoraQuotaService(userRepo, nil, nil)
+
+ _, err := svc.GetQuota(context.Background(), 999)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "get user")
+}
+
+func TestGetQuota_GroupRepoError(t *testing.T) {
+ // 分组获取失败时跳过该分组(不影响整体)
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{
+ ID: 1, SoraStorageQuotaBytes: 0,
+ AllowedGroups: []int64{999}, // 不存在的分组
+ }
+
+ groupRepo := newStubGroupRepoForQuota()
+ svc := NewSoraQuotaService(userRepo, groupRepo, nil)
+
+ quota, err := svc.GetQuota(context.Background(), 1)
+ require.NoError(t, err)
+ require.Equal(t, "unlimited", quota.Source) // 分组获取失败,回退到无限制
+}
+
+// ==================== CheckQuota ====================
+
+func TestCheckQuota_Sufficient(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{
+ ID: 1,
+ SoraStorageQuotaBytes: 10 * 1024 * 1024,
+ SoraStorageUsedBytes: 3 * 1024 * 1024,
+ }
+ svc := NewSoraQuotaService(userRepo, nil, nil)
+
+ err := svc.CheckQuota(context.Background(), 1, 1024)
+ require.NoError(t, err)
+}
+
+func TestCheckQuota_Exceeded(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{
+ ID: 1,
+ SoraStorageQuotaBytes: 10 * 1024 * 1024,
+ SoraStorageUsedBytes: 10 * 1024 * 1024, // 已满
+ }
+ svc := NewSoraQuotaService(userRepo, nil, nil)
+
+ err := svc.CheckQuota(context.Background(), 1, 1)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "配额不足")
+}
+
+func TestCheckQuota_NoLimit(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{
+ ID: 1,
+ SoraStorageQuotaBytes: 0, // 无限制
+ SoraStorageUsedBytes: 1000000000,
+ }
+ svc := NewSoraQuotaService(userRepo, nil, nil)
+
+ err := svc.CheckQuota(context.Background(), 1, 999999999)
+ require.NoError(t, err) // 无限制时始终通过
+}
+
+func TestCheckQuota_ExactBoundary(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{
+ ID: 1,
+ SoraStorageQuotaBytes: 1024,
+ SoraStorageUsedBytes: 1024, // 恰好满
+ }
+ svc := NewSoraQuotaService(userRepo, nil, nil)
+
+ // 额外 0 字节不超
+ require.NoError(t, svc.CheckQuota(context.Background(), 1, 0))
+ // 额外 1 字节超出
+ require.Error(t, svc.CheckQuota(context.Background(), 1, 1))
+}
+
+// ==================== AddUsage ====================
+
+func TestAddUsage_Success(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
+ svc := NewSoraQuotaService(userRepo, nil, nil)
+
+ err := svc.AddUsage(context.Background(), 1, 2048)
+ require.NoError(t, err)
+ require.Equal(t, int64(3072), userRepo.users[1].SoraStorageUsedBytes)
+}
+
+func TestAddUsage_ZeroBytes(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
+ svc := NewSoraQuotaService(userRepo, nil, nil)
+
+ err := svc.AddUsage(context.Background(), 1, 0)
+ require.NoError(t, err)
+ require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变
+}
+
+func TestAddUsage_NegativeBytes(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
+ svc := NewSoraQuotaService(userRepo, nil, nil)
+
+ err := svc.AddUsage(context.Background(), 1, -100)
+ require.NoError(t, err)
+ require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变
+}
+
+func TestAddUsage_UserNotFound(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ svc := NewSoraQuotaService(userRepo, nil, nil)
+
+ err := svc.AddUsage(context.Background(), 999, 1024)
+ require.Error(t, err)
+}
+
+func TestAddUsage_UpdateError(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 0}
+ userRepo.updateErr = fmt.Errorf("db error")
+ svc := NewSoraQuotaService(userRepo, nil, nil)
+
+ err := svc.AddUsage(context.Background(), 1, 1024)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "update user quota usage")
+}
+
+// ==================== ReleaseUsage ====================
+
+func TestReleaseUsage_Success(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 3072}
+ svc := NewSoraQuotaService(userRepo, nil, nil)
+
+ err := svc.ReleaseUsage(context.Background(), 1, 1024)
+ require.NoError(t, err)
+ require.Equal(t, int64(2048), userRepo.users[1].SoraStorageUsedBytes)
+}
+
+func TestReleaseUsage_ClampToZero(t *testing.T) {
+ // 释放量大于已用量时,应 clamp 到 0
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 500}
+ svc := NewSoraQuotaService(userRepo, nil, nil)
+
+ err := svc.ReleaseUsage(context.Background(), 1, 1000)
+ require.NoError(t, err)
+ require.Equal(t, int64(0), userRepo.users[1].SoraStorageUsedBytes)
+}
+
+func TestReleaseUsage_ZeroBytes(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
+ svc := NewSoraQuotaService(userRepo, nil, nil)
+
+ err := svc.ReleaseUsage(context.Background(), 1, 0)
+ require.NoError(t, err)
+ require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变
+}
+
+func TestReleaseUsage_NegativeBytes(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
+ svc := NewSoraQuotaService(userRepo, nil, nil)
+
+ err := svc.ReleaseUsage(context.Background(), 1, -50)
+ require.NoError(t, err)
+ require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变
+}
+
+func TestReleaseUsage_UserNotFound(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ svc := NewSoraQuotaService(userRepo, nil, nil)
+
+ err := svc.ReleaseUsage(context.Background(), 999, 1024)
+ require.Error(t, err)
+}
+
+func TestReleaseUsage_UpdateError(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024}
+ userRepo.updateErr = fmt.Errorf("db error")
+ svc := NewSoraQuotaService(userRepo, nil, nil)
+
+ err := svc.ReleaseUsage(context.Background(), 1, 512)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "update user quota release")
+}
+
+// ==================== GetQuotaFromSettings ====================
+
+func TestGetQuotaFromSettings_NilSettingService(t *testing.T) {
+ svc := NewSoraQuotaService(nil, nil, nil)
+ require.Equal(t, int64(0), svc.GetQuotaFromSettings(context.Background()))
+}
+
+func TestGetQuotaFromSettings_WithSettings(t *testing.T) {
+ settingRepo := newStubSettingRepoForQuota(map[string]string{
+ SettingKeySoraDefaultStorageQuotaBytes: "52428800", // 50MB
+ })
+ settingService := NewSettingService(settingRepo, &config.Config{})
+ svc := NewSoraQuotaService(nil, nil, settingService)
+
+ require.Equal(t, int64(52428800), svc.GetQuotaFromSettings(context.Background()))
+}
+
+// ==================== SetUserSoraQuota ====================
+
+func TestSetUserSoraQuota_Success(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{ID: 1, SoraStorageQuotaBytes: 0}
+
+ err := SetUserSoraQuota(context.Background(), userRepo, 1, 10*1024*1024)
+ require.NoError(t, err)
+ require.Equal(t, int64(10*1024*1024), userRepo.users[1].SoraStorageQuotaBytes)
+}
+
+func TestSetUserSoraQuota_UserNotFound(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ err := SetUserSoraQuota(context.Background(), userRepo, 999, 1024)
+ require.Error(t, err)
+}
+
+// ==================== ParseQuotaBytes ====================
+
+func TestParseQuotaBytes(t *testing.T) {
+ require.Equal(t, int64(1048576), ParseQuotaBytes("1048576"))
+ require.Equal(t, int64(0), ParseQuotaBytes(""))
+ require.Equal(t, int64(0), ParseQuotaBytes("abc"))
+ require.Equal(t, int64(-1), ParseQuotaBytes("-1"))
+}
+
+// ==================== 优先级完整测试 ====================
+
+func TestQuotaPriority_UserOverridesGroup(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{
+ ID: 1,
+ SoraStorageQuotaBytes: 5 * 1024 * 1024,
+ AllowedGroups: []int64{10},
+ }
+
+ groupRepo := newStubGroupRepoForQuota()
+ groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 20 * 1024 * 1024}
+
+ svc := NewSoraQuotaService(userRepo, groupRepo, nil)
+ quota, err := svc.GetQuota(context.Background(), 1)
+ require.NoError(t, err)
+ require.Equal(t, "user", quota.Source) // 用户级优先
+ require.Equal(t, int64(5*1024*1024), quota.QuotaBytes)
+}
+
+func TestQuotaPriority_GroupOverridesSystem(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{
+ ID: 1,
+ SoraStorageQuotaBytes: 0,
+ AllowedGroups: []int64{10},
+ }
+
+ groupRepo := newStubGroupRepoForQuota()
+ groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 20 * 1024 * 1024}
+
+ settingRepo := newStubSettingRepoForQuota(map[string]string{
+ SettingKeySoraDefaultStorageQuotaBytes: "104857600", // 100MB
+ })
+ settingService := NewSettingService(settingRepo, &config.Config{})
+
+ svc := NewSoraQuotaService(userRepo, groupRepo, settingService)
+ quota, err := svc.GetQuota(context.Background(), 1)
+ require.NoError(t, err)
+ require.Equal(t, "group", quota.Source) // 分组级优先于系统
+ require.Equal(t, int64(20*1024*1024), quota.QuotaBytes)
+}
+
+func TestQuotaPriority_FallbackToSystem(t *testing.T) {
+ userRepo := newStubUserRepoForQuota()
+ userRepo.users[1] = &User{
+ ID: 1,
+ SoraStorageQuotaBytes: 0,
+ AllowedGroups: []int64{10},
+ }
+
+ groupRepo := newStubGroupRepoForQuota()
+ groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 0} // 分组无配额
+
+ settingRepo := newStubSettingRepoForQuota(map[string]string{
+ SettingKeySoraDefaultStorageQuotaBytes: "52428800", // 50MB
+ })
+ settingService := NewSettingService(settingRepo, &config.Config{})
+
+ svc := NewSoraQuotaService(userRepo, groupRepo, settingService)
+ quota, err := svc.GetQuota(context.Background(), 1)
+ require.NoError(t, err)
+ require.Equal(t, "system", quota.Source)
+ require.Equal(t, int64(52428800), quota.QuotaBytes)
+}
diff --git a/backend/internal/service/sora_s3_storage.go b/backend/internal/service/sora_s3_storage.go
new file mode 100644
index 00000000..4c573905
--- /dev/null
+++ b/backend/internal/service/sora_s3_storage.go
@@ -0,0 +1,392 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+ "path"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
+ awsconfig "github.com/aws/aws-sdk-go-v2/config"
+ "github.com/aws/aws-sdk-go-v2/credentials"
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+ "github.com/google/uuid"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+)
+
+// SoraS3Storage 负责 Sora 媒体文件的 S3 存储操作。
+// 从 Settings 表读取 S3 配置,初始化并缓存 S3 客户端。
+type SoraS3Storage struct {
+ settingService *SettingService
+
+ mu sync.RWMutex
+ client *s3.Client
+ cfg *SoraS3Settings // 上次加载的配置快照
+
+ healthCheckedAt time.Time
+ healthErr error
+ healthTTL time.Duration
+}
+
+const defaultSoraS3HealthTTL = 30 * time.Second
+
+// UpstreamDownloadError 表示从上游下载媒体失败(包含 HTTP 状态码)。
+type UpstreamDownloadError struct {
+ StatusCode int
+}
+
+func (e *UpstreamDownloadError) Error() string {
+ if e == nil {
+ return "upstream download failed"
+ }
+ return fmt.Sprintf("upstream returned %d", e.StatusCode)
+}
+
+// NewSoraS3Storage 创建 S3 存储服务实例。
+func NewSoraS3Storage(settingService *SettingService) *SoraS3Storage {
+ return &SoraS3Storage{
+ settingService: settingService,
+ healthTTL: defaultSoraS3HealthTTL,
+ }
+}
+
+// Enabled 返回 S3 存储是否已启用且配置有效。
+func (s *SoraS3Storage) Enabled(ctx context.Context) bool {
+ cfg, err := s.getConfig(ctx)
+ if err != nil || cfg == nil {
+ return false
+ }
+ return cfg.Enabled && cfg.Bucket != ""
+}
+
+// getConfig 获取当前 S3 配置(从 settings 表读取)。
+func (s *SoraS3Storage) getConfig(ctx context.Context) (*SoraS3Settings, error) {
+ if s.settingService == nil {
+ return nil, fmt.Errorf("setting service not available")
+ }
+ return s.settingService.GetSoraS3Settings(ctx)
+}
+
+// getClient 获取或初始化 S3 客户端(带缓存)。
+// 配置变更时调用 RefreshClient 清除缓存。
+func (s *SoraS3Storage) getClient(ctx context.Context) (*s3.Client, *SoraS3Settings, error) {
+ s.mu.RLock()
+ if s.client != nil && s.cfg != nil {
+ client, cfg := s.client, s.cfg
+ s.mu.RUnlock()
+ return client, cfg, nil
+ }
+ s.mu.RUnlock()
+
+ return s.initClient(ctx)
+}
+
+func (s *SoraS3Storage) initClient(ctx context.Context) (*s3.Client, *SoraS3Settings, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // 双重检查
+ if s.client != nil && s.cfg != nil {
+ return s.client, s.cfg, nil
+ }
+
+ cfg, err := s.getConfig(ctx)
+ if err != nil {
+ return nil, nil, fmt.Errorf("load s3 config: %w", err)
+ }
+ if !cfg.Enabled {
+ return nil, nil, fmt.Errorf("sora s3 storage is disabled")
+ }
+ if cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
+ return nil, nil, fmt.Errorf("sora s3 config incomplete: bucket, access_key_id, secret_access_key are required")
+ }
+
+ client, region, err := buildSoraS3Client(ctx, cfg)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ s.client = client
+ s.cfg = cfg
+ logger.LegacyPrintf("service.sora_s3", "[SoraS3] 客户端已初始化 bucket=%s endpoint=%s region=%s", cfg.Bucket, cfg.Endpoint, region)
+ return client, cfg, nil
+}
+
+// RefreshClient 清除缓存的 S3 客户端,下次使用时重新初始化。
+// 应在系统设置中 S3 配置变更时调用。
+func (s *SoraS3Storage) RefreshClient() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.client = nil
+ s.cfg = nil
+ s.healthCheckedAt = time.Time{}
+ s.healthErr = nil
+ logger.LegacyPrintf("service.sora_s3", "[SoraS3] 客户端缓存已清除,下次使用将重新初始化")
+}
+
+// TestConnection 测试 S3 连接(HeadBucket)。
+func (s *SoraS3Storage) TestConnection(ctx context.Context) error {
+ client, cfg, err := s.getClient(ctx)
+ if err != nil {
+ return err
+ }
+ _, err = client.HeadBucket(ctx, &s3.HeadBucketInput{
+ Bucket: &cfg.Bucket,
+ })
+ if err != nil {
+ return fmt.Errorf("s3 HeadBucket failed: %w", err)
+ }
+ return nil
+}
+
+// IsHealthy 返回 S3 健康状态(带短缓存,避免每次请求都触发 HeadBucket)。
+func (s *SoraS3Storage) IsHealthy(ctx context.Context) bool {
+ if s == nil {
+ return false
+ }
+ now := time.Now()
+ s.mu.RLock()
+ lastCheck := s.healthCheckedAt
+ lastErr := s.healthErr
+ ttl := s.healthTTL
+ s.mu.RUnlock()
+
+ if ttl <= 0 {
+ ttl = defaultSoraS3HealthTTL
+ }
+ if !lastCheck.IsZero() && now.Sub(lastCheck) < ttl {
+ return lastErr == nil
+ }
+
+ err := s.TestConnection(ctx)
+ s.mu.Lock()
+ s.healthCheckedAt = time.Now()
+ s.healthErr = err
+ s.mu.Unlock()
+ return err == nil
+}
+
+// TestConnectionWithSettings 使用临时配置测试连接,不污染缓存的客户端。
+func (s *SoraS3Storage) TestConnectionWithSettings(ctx context.Context, cfg *SoraS3Settings) error {
+ if cfg == nil {
+ return fmt.Errorf("s3 config is required")
+ }
+ if !cfg.Enabled {
+ return fmt.Errorf("sora s3 storage is disabled")
+ }
+ if cfg.Endpoint == "" || cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
+ return fmt.Errorf("sora s3 config incomplete: endpoint, bucket, access_key_id, secret_access_key are required")
+ }
+ client, _, err := buildSoraS3Client(ctx, cfg)
+ if err != nil {
+ return err
+ }
+ _, err = client.HeadBucket(ctx, &s3.HeadBucketInput{
+ Bucket: &cfg.Bucket,
+ })
+ if err != nil {
+ return fmt.Errorf("s3 HeadBucket failed: %w", err)
+ }
+ return nil
+}
+
+// GenerateObjectKey 生成 S3 object key。
+// 格式: {prefix}sora/{userID}/{YYYY/MM/DD}/{uuid}.{ext}
+func (s *SoraS3Storage) GenerateObjectKey(prefix string, userID int64, ext string) string {
+ if !strings.HasPrefix(ext, ".") {
+ ext = "." + ext
+ }
+ datePath := time.Now().Format("2006/01/02")
+ key := fmt.Sprintf("sora/%d/%s/%s%s", userID, datePath, uuid.NewString(), ext)
+ if prefix != "" {
+ prefix = strings.TrimRight(prefix, "/") + "/"
+ key = prefix + key
+ }
+ return key
+}
+
+// UploadFromURL 从上游 URL 下载并流式上传到 S3。
+// 返回 S3 object key。
+func (s *SoraS3Storage) UploadFromURL(ctx context.Context, userID int64, sourceURL string) (string, int64, error) {
+ client, cfg, err := s.getClient(ctx)
+ if err != nil {
+ return "", 0, err
+ }
+
+ // 下载源文件
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL, nil)
+ if err != nil {
+ return "", 0, fmt.Errorf("create download request: %w", err)
+ }
+ httpClient := &http.Client{Timeout: 5 * time.Minute}
+ resp, err := httpClient.Do(req)
+ if err != nil {
+ return "", 0, fmt.Errorf("download from upstream: %w", err)
+ }
+ defer func() {
+ _ = resp.Body.Close()
+ }()
+
+ if resp.StatusCode != http.StatusOK {
+ return "", 0, &UpstreamDownloadError{StatusCode: resp.StatusCode}
+ }
+
+ // 推断文件扩展名
+ ext := fileExtFromURL(sourceURL)
+ if ext == "" {
+ ext = fileExtFromContentType(resp.Header.Get("Content-Type"))
+ }
+ if ext == "" {
+ ext = ".bin"
+ }
+
+ objectKey := s.GenerateObjectKey(cfg.Prefix, userID, ext)
+
+ // 检测 Content-Type
+ contentType := resp.Header.Get("Content-Type")
+ if contentType == "" {
+ contentType = "application/octet-stream"
+ }
+
+ reader, writer := io.Pipe()
+ uploadErrCh := make(chan error, 1)
+ go func() {
+ defer close(uploadErrCh)
+ input := &s3.PutObjectInput{
+ Bucket: &cfg.Bucket,
+ Key: &objectKey,
+ Body: reader,
+ ContentType: &contentType,
+ }
+ if resp.ContentLength >= 0 {
+ input.ContentLength = &resp.ContentLength
+ }
+ _, uploadErr := client.PutObject(ctx, input)
+ uploadErrCh <- uploadErr
+ }()
+
+ written, copyErr := io.CopyBuffer(writer, resp.Body, make([]byte, 1024*1024))
+ _ = writer.CloseWithError(copyErr)
+ uploadErr := <-uploadErrCh
+ if copyErr != nil {
+ return "", 0, fmt.Errorf("stream upload copy failed: %w", copyErr)
+ }
+ if uploadErr != nil {
+ return "", 0, fmt.Errorf("s3 upload: %w", uploadErr)
+ }
+
+ logger.LegacyPrintf("service.sora_s3", "[SoraS3] 上传完成 key=%s size=%d", objectKey, written)
+ return objectKey, written, nil
+}
+
+func buildSoraS3Client(ctx context.Context, cfg *SoraS3Settings) (*s3.Client, string, error) {
+ if cfg == nil {
+ return nil, "", fmt.Errorf("s3 config is required")
+ }
+ region := cfg.Region
+ if region == "" {
+ region = "us-east-1"
+ }
+
+ awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
+ awsconfig.WithRegion(region),
+ awsconfig.WithCredentialsProvider(
+ credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
+ ),
+ )
+ if err != nil {
+ return nil, "", fmt.Errorf("load aws config: %w", err)
+ }
+
+ client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
+ if cfg.Endpoint != "" {
+ o.BaseEndpoint = &cfg.Endpoint
+ }
+ if cfg.ForcePathStyle {
+ o.UsePathStyle = true
+ }
+ o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)
+ // 兼容非 TLS 连接(如 MinIO)的流式上传,避免 io.Pipe checksum 校验失败
+ o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
+ })
+ return client, region, nil
+}
+
+// DeleteObjects 删除一组 S3 object(遍历逐一删除)。
+func (s *SoraS3Storage) DeleteObjects(ctx context.Context, objectKeys []string) error {
+ if len(objectKeys) == 0 {
+ return nil
+ }
+
+ client, cfg, err := s.getClient(ctx)
+ if err != nil {
+ return err
+ }
+
+ var lastErr error
+ for _, key := range objectKeys {
+ k := key
+ _, err := client.DeleteObject(ctx, &s3.DeleteObjectInput{
+ Bucket: &cfg.Bucket,
+ Key: &k,
+ })
+ if err != nil {
+ logger.LegacyPrintf("service.sora_s3", "[SoraS3] 删除失败 key=%s err=%v", key, err)
+ lastErr = err
+ }
+ }
+ return lastErr
+}
+
+// GetAccessURL 获取 S3 文件的访问 URL。
+// CDN URL 优先,否则生成 24h 预签名 URL。
+func (s *SoraS3Storage) GetAccessURL(ctx context.Context, objectKey string) (string, error) {
+ _, cfg, err := s.getClient(ctx)
+ if err != nil {
+ return "", err
+ }
+
+ // CDN URL 优先
+ if cfg.CDNURL != "" {
+ cdnBase := strings.TrimRight(cfg.CDNURL, "/")
+ return cdnBase + "/" + objectKey, nil
+ }
+
+ // 生成 24h 预签名 URL
+ return s.GeneratePresignedURL(ctx, objectKey, 24*time.Hour)
+}
+
+// GeneratePresignedURL 生成预签名 URL。
+func (s *SoraS3Storage) GeneratePresignedURL(ctx context.Context, objectKey string, ttl time.Duration) (string, error) {
+ client, cfg, err := s.getClient(ctx)
+ if err != nil {
+ return "", err
+ }
+
+ presignClient := s3.NewPresignClient(client)
+ result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
+ Bucket: &cfg.Bucket,
+ Key: &objectKey,
+ }, s3.WithPresignExpires(ttl))
+ if err != nil {
+ return "", fmt.Errorf("presign url: %w", err)
+ }
+ return result.URL, nil
+}
+
+// GetMediaType 从 object key 推断媒体类型(image/video)。
+func GetMediaTypeFromKey(objectKey string) string {
+ ext := strings.ToLower(path.Ext(objectKey))
+ switch ext {
+ case ".mp4", ".mov", ".webm", ".m4v", ".avi", ".mkv", ".3gp", ".flv":
+ return "video"
+ default:
+ return "image"
+ }
+}
diff --git a/backend/internal/service/sora_s3_storage_test.go b/backend/internal/service/sora_s3_storage_test.go
new file mode 100644
index 00000000..32ff9a6f
--- /dev/null
+++ b/backend/internal/service/sora_s3_storage_test.go
@@ -0,0 +1,263 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+// ==================== RefreshClient ====================
+
+func TestRefreshClient(t *testing.T) {
+ s := newS3StorageWithCDN("https://cdn.example.com")
+ require.NotNil(t, s.client)
+ require.NotNil(t, s.cfg)
+
+ s.RefreshClient()
+ require.Nil(t, s.client)
+ require.Nil(t, s.cfg)
+}
+
+func TestRefreshClient_AlreadyNil(t *testing.T) {
+ s := NewSoraS3Storage(nil)
+ s.RefreshClient() // 不应 panic
+ require.Nil(t, s.client)
+ require.Nil(t, s.cfg)
+}
+
+// ==================== GetMediaTypeFromKey ====================
+
+func TestGetMediaTypeFromKey_VideoExtensions(t *testing.T) {
+ for _, ext := range []string{".mp4", ".mov", ".webm", ".m4v", ".avi", ".mkv", ".3gp", ".flv"} {
+ require.Equal(t, "video", GetMediaTypeFromKey("path/to/file"+ext), "ext=%s", ext)
+ }
+}
+
+func TestGetMediaTypeFromKey_VideoUpperCase(t *testing.T) {
+ require.Equal(t, "video", GetMediaTypeFromKey("file.MP4"))
+ require.Equal(t, "video", GetMediaTypeFromKey("file.MOV"))
+}
+
+func TestGetMediaTypeFromKey_ImageExtensions(t *testing.T) {
+ require.Equal(t, "image", GetMediaTypeFromKey("file.png"))
+ require.Equal(t, "image", GetMediaTypeFromKey("file.jpg"))
+ require.Equal(t, "image", GetMediaTypeFromKey("file.jpeg"))
+ require.Equal(t, "image", GetMediaTypeFromKey("file.gif"))
+ require.Equal(t, "image", GetMediaTypeFromKey("file.webp"))
+}
+
+func TestGetMediaTypeFromKey_NoExtension(t *testing.T) {
+ require.Equal(t, "image", GetMediaTypeFromKey("file"))
+ require.Equal(t, "image", GetMediaTypeFromKey("path/to/file"))
+}
+
+func TestGetMediaTypeFromKey_UnknownExtension(t *testing.T) {
+ require.Equal(t, "image", GetMediaTypeFromKey("file.bin"))
+ require.Equal(t, "image", GetMediaTypeFromKey("file.xyz"))
+}
+
+// ==================== Enabled ====================
+
+func TestEnabled_NilSettingService(t *testing.T) {
+ s := NewSoraS3Storage(nil)
+ require.False(t, s.Enabled(context.Background()))
+}
+
+func TestEnabled_ConfigDisabled(t *testing.T) {
+ settingRepo := newStubSettingRepoForQuota(map[string]string{
+ SettingKeySoraS3Enabled: "false",
+ SettingKeySoraS3Bucket: "test-bucket",
+ })
+ settingService := NewSettingService(settingRepo, &config.Config{})
+ s := NewSoraS3Storage(settingService)
+ require.False(t, s.Enabled(context.Background()))
+}
+
+func TestEnabled_ConfigEnabledWithBucket(t *testing.T) {
+ settingRepo := newStubSettingRepoForQuota(map[string]string{
+ SettingKeySoraS3Enabled: "true",
+ SettingKeySoraS3Bucket: "my-bucket",
+ })
+ settingService := NewSettingService(settingRepo, &config.Config{})
+ s := NewSoraS3Storage(settingService)
+ require.True(t, s.Enabled(context.Background()))
+}
+
+func TestEnabled_ConfigEnabledEmptyBucket(t *testing.T) {
+ settingRepo := newStubSettingRepoForQuota(map[string]string{
+ SettingKeySoraS3Enabled: "true",
+ })
+ settingService := NewSettingService(settingRepo, &config.Config{})
+ s := NewSoraS3Storage(settingService)
+ require.False(t, s.Enabled(context.Background()))
+}
+
+// ==================== initClient ====================
+
+func TestInitClient_Disabled(t *testing.T) {
+ settingRepo := newStubSettingRepoForQuota(map[string]string{
+ SettingKeySoraS3Enabled: "false",
+ })
+ settingService := NewSettingService(settingRepo, &config.Config{})
+ s := NewSoraS3Storage(settingService)
+
+ _, _, err := s.getClient(context.Background())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "disabled")
+}
+
+func TestInitClient_IncompleteConfig(t *testing.T) {
+ settingRepo := newStubSettingRepoForQuota(map[string]string{
+ SettingKeySoraS3Enabled: "true",
+ SettingKeySoraS3Bucket: "test-bucket",
+ // 缺少 access_key_id 和 secret_access_key
+ })
+ settingService := NewSettingService(settingRepo, &config.Config{})
+ s := NewSoraS3Storage(settingService)
+
+ _, _, err := s.getClient(context.Background())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "incomplete")
+}
+
+func TestInitClient_DefaultRegion(t *testing.T) {
+ settingRepo := newStubSettingRepoForQuota(map[string]string{
+ SettingKeySoraS3Enabled: "true",
+ SettingKeySoraS3Bucket: "test-bucket",
+ SettingKeySoraS3AccessKeyID: "AKID",
+ SettingKeySoraS3SecretAccessKey: "SECRET",
+ // Region 为空 → 默认 us-east-1
+ })
+ settingService := NewSettingService(settingRepo, &config.Config{})
+ s := NewSoraS3Storage(settingService)
+
+ client, cfg, err := s.getClient(context.Background())
+ require.NoError(t, err)
+ require.NotNil(t, client)
+ require.Equal(t, "test-bucket", cfg.Bucket)
+}
+
+func TestInitClient_DoubleCheck(t *testing.T) {
+ // 验证双重检查锁定:第二次 getClient 命中缓存
+ settingRepo := newStubSettingRepoForQuota(map[string]string{
+ SettingKeySoraS3Enabled: "true",
+ SettingKeySoraS3Bucket: "test-bucket",
+ SettingKeySoraS3AccessKeyID: "AKID",
+ SettingKeySoraS3SecretAccessKey: "SECRET",
+ })
+ settingService := NewSettingService(settingRepo, &config.Config{})
+ s := NewSoraS3Storage(settingService)
+
+ client1, _, err1 := s.getClient(context.Background())
+ require.NoError(t, err1)
+ client2, _, err2 := s.getClient(context.Background())
+ require.NoError(t, err2)
+ require.Equal(t, client1, client2) // 同一客户端实例
+}
+
+func TestInitClient_NilSettingService(t *testing.T) {
+ s := NewSoraS3Storage(nil)
+ _, _, err := s.getClient(context.Background())
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "setting service not available")
+}
+
+// ==================== GenerateObjectKey ====================
+
+func TestGenerateObjectKey_ExtWithoutDot(t *testing.T) {
+ s := NewSoraS3Storage(nil)
+ key := s.GenerateObjectKey("", 1, "mp4")
+ require.Contains(t, key, ".mp4")
+ require.True(t, len(key) > 0)
+}
+
+func TestGenerateObjectKey_ExtWithDot(t *testing.T) {
+ s := NewSoraS3Storage(nil)
+ key := s.GenerateObjectKey("", 1, ".mp4")
+ require.Contains(t, key, ".mp4")
+ // 不应出现 ..mp4
+ require.NotContains(t, key, "..mp4")
+}
+
+func TestGenerateObjectKey_WithPrefix(t *testing.T) {
+ s := NewSoraS3Storage(nil)
+ key := s.GenerateObjectKey("uploads/", 42, ".png")
+ require.True(t, len(key) > 0)
+ require.Contains(t, key, "uploads/sora/42/")
+}
+
+func TestGenerateObjectKey_PrefixWithoutTrailingSlash(t *testing.T) {
+ s := NewSoraS3Storage(nil)
+ key := s.GenerateObjectKey("uploads", 42, ".png")
+ require.Contains(t, key, "uploads/sora/42/")
+}
+
+// ==================== GeneratePresignedURL ====================
+
+func TestGeneratePresignedURL_GetClientError(t *testing.T) {
+ s := NewSoraS3Storage(nil) // settingService=nil → getClient 失败
+ _, err := s.GeneratePresignedURL(context.Background(), "key", 3600)
+ require.Error(t, err)
+}
+
+// ==================== GetAccessURL ====================
+
+func TestGetAccessURL_CDN(t *testing.T) {
+ s := newS3StorageWithCDN("https://cdn.example.com")
+ url, err := s.GetAccessURL(context.Background(), "sora/1/2024/01/01/video.mp4")
+ require.NoError(t, err)
+ require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/video.mp4", url)
+}
+
+func TestGetAccessURL_CDNTrailingSlash(t *testing.T) {
+ s := newS3StorageWithCDN("https://cdn.example.com/")
+ url, err := s.GetAccessURL(context.Background(), "key.mp4")
+ require.NoError(t, err)
+ require.Equal(t, "https://cdn.example.com/key.mp4", url)
+}
+
+func TestGetAccessURL_GetClientError(t *testing.T) {
+ s := NewSoraS3Storage(nil)
+ _, err := s.GetAccessURL(context.Background(), "key")
+ require.Error(t, err)
+}
+
+// ==================== TestConnection ====================
+
+func TestTestConnection_GetClientError(t *testing.T) {
+ s := NewSoraS3Storage(nil)
+ err := s.TestConnection(context.Background())
+ require.Error(t, err)
+}
+
+// ==================== UploadFromURL ====================
+
+func TestUploadFromURL_GetClientError(t *testing.T) {
+ s := NewSoraS3Storage(nil)
+ _, _, err := s.UploadFromURL(context.Background(), 1, "https://example.com/file.mp4")
+ require.Error(t, err)
+}
+
+// ==================== DeleteObjects ====================
+
+func TestDeleteObjects_EmptyKeys(t *testing.T) {
+ s := NewSoraS3Storage(nil)
+ err := s.DeleteObjects(context.Background(), []string{})
+ require.NoError(t, err) // 空列表直接返回
+}
+
+func TestDeleteObjects_NilKeys(t *testing.T) {
+ s := NewSoraS3Storage(nil)
+ err := s.DeleteObjects(context.Background(), nil)
+ require.NoError(t, err) // nil 列表直接返回
+}
+
+func TestDeleteObjects_GetClientError(t *testing.T) {
+ s := NewSoraS3Storage(nil)
+ err := s.DeleteObjects(context.Background(), []string{"key1", "key2"})
+ require.Error(t, err)
+}
diff --git a/backend/internal/service/sora_sdk_client.go b/backend/internal/service/sora_sdk_client.go
index 604c2749..f9221c5b 100644
--- a/backend/internal/service/sora_sdk_client.go
+++ b/backend/internal/service/sora_sdk_client.go
@@ -15,6 +15,7 @@ import (
"github.com/DouDOU-start/go-sora2api/sora"
"github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/tidwall/gjson"
@@ -75,6 +76,17 @@ func (c *SoraSDKClient) PreflightCheck(ctx context.Context, account *Account, re
}
balance, err := sdkClient.GetCreditBalance(ctx, token)
if err != nil {
+ accountID := int64(0)
+ if account != nil {
+ accountID = account.ID
+ }
+ logger.LegacyPrintf(
+ "service.sora_sdk",
+ "[PreflightCheckRawError] account_id=%d model=%s op=get_credit_balance raw_err=%s",
+ accountID,
+ requestedModel,
+ logredact.RedactText(err.Error()),
+ )
return &SoraUpstreamError{
StatusCode: http.StatusForbidden,
Message: "当前账号未开通 Sora2 能力或无可用配额",
@@ -170,9 +182,23 @@ func (c *SoraSDKClient) CreateVideoTask(ctx context.Context, account *Account, r
if size == "" {
size = "small"
}
+ videoCount := req.VideoCount
+ if videoCount <= 0 {
+ videoCount = 1
+ }
+ if videoCount > 3 {
+ videoCount = 3
+ }
// Remix 模式
if strings.TrimSpace(req.RemixTargetID) != "" {
+ if videoCount > 1 {
+ accountID := int64(0)
+ if account != nil {
+ accountID = account.ID
+ }
+ c.debugLogf("video_count_ignored_for_remix account_id=%d count=%d", accountID, videoCount)
+ }
styleID := "" // SDK ExtractStyle 可从 prompt 中提取
taskID, err := sdkClient.RemixVideo(ctx, token, sentinel, req.RemixTargetID, req.Prompt, orientation, nFrames, styleID)
if err != nil {
@@ -182,13 +208,60 @@ func (c *SoraSDKClient) CreateVideoTask(ctx context.Context, account *Account, r
}
// 普通视频(文生视频或图生视频)
- taskID, err := sdkClient.CreateVideoTaskWithOptions(ctx, token, sentinel, req.Prompt, orientation, nFrames, model, size, req.MediaID, "")
+ var taskID string
+ if videoCount <= 1 {
+ taskID, err = sdkClient.CreateVideoTaskWithOptions(ctx, token, sentinel, req.Prompt, orientation, nFrames, model, size, req.MediaID, "")
+ } else {
+ taskID, err = c.createVideoTaskWithVariants(ctx, account, token, sentinel, req.Prompt, orientation, nFrames, model, size, req.MediaID, videoCount)
+ }
if err != nil {
return "", c.wrapSDKError(err, account)
}
return taskID, nil
}
+func (c *SoraSDKClient) createVideoTaskWithVariants(
+ ctx context.Context,
+ account *Account,
+ accessToken string,
+ sentinelToken string,
+ prompt string,
+ orientation string,
+ nFrames int,
+ model string,
+ size string,
+ mediaID string,
+ videoCount int,
+) (string, error) {
+ inpaintItems := make([]any, 0, 1)
+ if strings.TrimSpace(mediaID) != "" {
+ inpaintItems = append(inpaintItems, map[string]any{
+ "kind": "upload",
+ "upload_id": mediaID,
+ })
+ }
+ payload := map[string]any{
+ "kind": "video",
+ "prompt": prompt,
+ "orientation": orientation,
+ "size": size,
+ "n_frames": nFrames,
+ "n_variants": videoCount,
+ "model": model,
+ "inpaint_items": inpaintItems,
+ "style_id": nil,
+ }
+ raw, err := c.doSoraBackendJSON(ctx, account, http.MethodPost, "/nf/create", accessToken, sentinelToken, payload)
+ if err != nil {
+ return "", err
+ }
+ taskID := strings.TrimSpace(gjson.GetBytes(raw, "id").String())
+ if taskID == "" {
+ return "", errors.New("create video task response missing id")
+ }
+ return taskID, nil
+}
+
func (c *SoraSDKClient) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
token, err := c.getAccessToken(ctx, account)
if err != nil {
@@ -512,7 +585,7 @@ func (c *SoraSDKClient) GetVideoTask(ctx context.Context, account *Account, task
}
// 任务不在 pending 中,查询 drafts 获取下载链接
- downloadURL, err := sdkClient.GetDownloadURL(ctx, token, taskID)
+ downloadURLs, err := c.getVideoTaskDownloadURLs(ctx, account, token, taskID)
if err != nil {
errMsg := err.Error()
if strings.Contains(errMsg, "内容违规") || strings.Contains(errMsg, "Content violates") {
@@ -528,13 +601,147 @@ func (c *SoraSDKClient) GetVideoTask(ctx context.Context, account *Account, task
Status: "processing",
}, nil
}
+ if len(downloadURLs) == 0 {
+ return &SoraVideoTaskStatus{
+ ID: taskID,
+ Status: "processing",
+ }, nil
+ }
return &SoraVideoTaskStatus{
ID: taskID,
Status: "completed",
- URLs: []string{downloadURL},
+ URLs: downloadURLs,
}, nil
}
+func (c *SoraSDKClient) getVideoTaskDownloadURLs(ctx context.Context, account *Account, accessToken, taskID string) ([]string, error) {
+ raw, err := c.doSoraBackendJSON(ctx, account, http.MethodGet, "/project_y/profile/drafts?limit=30", accessToken, "", nil)
+ if err != nil {
+ return nil, err
+ }
+ items := gjson.GetBytes(raw, "items")
+ if !items.Exists() || !items.IsArray() {
+ return nil, fmt.Errorf("drafts response missing items for task %s", taskID)
+ }
+ urlSet := make(map[string]struct{}, 4)
+ urls := make([]string, 0, 4)
+ items.ForEach(func(_, item gjson.Result) bool {
+ if strings.TrimSpace(item.Get("task_id").String()) != taskID {
+ return true
+ }
+ kind := strings.TrimSpace(item.Get("kind").String())
+ reason := strings.TrimSpace(item.Get("reason_str").String())
+ markdownReason := strings.TrimSpace(item.Get("markdown_reason_str").String())
+ if kind == "sora_content_violation" || reason != "" || markdownReason != "" {
+ if reason == "" {
+ reason = markdownReason
+ }
+ if reason == "" {
+ reason = "内容违规"
+ }
+ err = fmt.Errorf("内容违规: %s", reason)
+ return false
+ }
+ url := strings.TrimSpace(item.Get("downloadable_url").String())
+ if url == "" {
+ url = strings.TrimSpace(item.Get("url").String())
+ }
+ if url == "" {
+ return true
+ }
+ if _, exists := urlSet[url]; exists {
+ return true
+ }
+ urlSet[url] = struct{}{}
+ urls = append(urls, url)
+ return true
+ })
+ if err != nil {
+ return nil, err
+ }
+ if len(urls) > 0 {
+ return urls, nil
+ }
+
+ // 兼容旧 SDK 的兜底逻辑
+ sdkClient, sdkErr := c.getSDKClient(account)
+ if sdkErr != nil {
+ return nil, sdkErr
+ }
+ downloadURL, sdkErr := sdkClient.GetDownloadURL(ctx, accessToken, taskID)
+ if sdkErr != nil {
+ return nil, sdkErr
+ }
+ if strings.TrimSpace(downloadURL) == "" {
+ return nil, nil
+ }
+ return []string{downloadURL}, nil
+}
+
+func (c *SoraSDKClient) doSoraBackendJSON(
+ ctx context.Context,
+ account *Account,
+ method string,
+ path string,
+ accessToken string,
+ sentinelToken string,
+ payload map[string]any,
+) ([]byte, error) {
+ endpoint := "https://sora.chatgpt.com/backend" + path
+ var body io.Reader
+ if payload != nil {
+ raw, err := json.Marshal(payload)
+ if err != nil {
+ return nil, err
+ }
+ body = bytes.NewReader(raw)
+ }
+
+ req, err := http.NewRequestWithContext(ctx, method, endpoint, body)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Authorization", "Bearer "+accessToken)
+ req.Header.Set("Accept", "application/json, text/plain, */*")
+ req.Header.Set("Origin", "https://sora.chatgpt.com")
+ req.Header.Set("Referer", "https://sora.chatgpt.com/")
+ req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
+ if payload != nil {
+ req.Header.Set("Content-Type", "application/json")
+ }
+ if strings.TrimSpace(sentinelToken) != "" {
+ req.Header.Set("openai-sentinel-token", sentinelToken)
+ }
+
+ proxyURL := c.resolveProxyURL(account)
+ accountID := int64(0)
+ accountConcurrency := 0
+ if account != nil {
+ accountID = account.ID
+ accountConcurrency = account.Concurrency
+ }
+
+ var resp *http.Response
+ if c.httpUpstream != nil {
+ resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency)
+ } else {
+ resp, err = http.DefaultClient.Do(req)
+ }
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
+ if err != nil {
+ return nil, err
+ }
+ if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
+ return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncateForLog(raw, 256))
+ }
+ return raw, nil
+}
+
// --- 内部方法 ---
// getSDKClient 获取或创建指定代理的 SDK 客户端实例
@@ -791,6 +998,17 @@ func (c *SoraSDKClient) wrapSDKError(err error, account *Account) error {
} else if strings.Contains(msg, "HTTP 404") {
statusCode = http.StatusNotFound
}
+ accountID := int64(0)
+ if account != nil {
+ accountID = account.ID
+ }
+ logger.LegacyPrintf(
+ "service.sora_sdk",
+ "[WrapSDKError] account_id=%d mapped_status=%d raw_err=%s",
+ accountID,
+ statusCode,
+ logredact.RedactText(msg),
+ )
return &SoraUpstreamError{
StatusCode: statusCode,
Message: msg,
diff --git a/backend/internal/service/sora_upstream_forwarder.go b/backend/internal/service/sora_upstream_forwarder.go
new file mode 100644
index 00000000..cdf9570b
--- /dev/null
+++ b/backend/internal/service/sora_upstream_forwarder.go
@@ -0,0 +1,149 @@
+package service
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ "github.com/gin-gonic/gin"
+)
+
+// forwardToUpstream 将请求 HTTP 透传到上游 Sora 服务(用于 apikey 类型账号)。
+// 上游地址为 account.GetBaseURL() + "/sora/v1/chat/completions",
+// 使用 account.GetCredential("api_key") 作为 Bearer Token。
+// 支持流式和非流式响应的直接透传。
+func (s *SoraGatewayService) forwardToUpstream(
+ ctx context.Context,
+ c *gin.Context,
+ account *Account,
+ body []byte,
+ clientStream bool,
+ startTime time.Time,
+) (*ForwardResult, error) {
+ apiKey := account.GetCredential("api_key")
+ if apiKey == "" {
+ s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing api_key credential", clientStream)
+ return nil, fmt.Errorf("sora apikey account %d missing api_key", account.ID)
+ }
+
+ baseURL := account.GetBaseURL()
+ if baseURL == "" {
+ s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing base_url", clientStream)
+ return nil, fmt.Errorf("sora apikey account %d missing base_url", account.ID)
+ }
+ // 校验 scheme 合法性(仅允许 http/https)
+ if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
+ s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey base_url must start with http:// or https://", clientStream)
+ return nil, fmt.Errorf("sora apikey account %d invalid base_url scheme: %s", account.ID, baseURL)
+ }
+ upstreamURL := strings.TrimRight(baseURL, "/") + "/sora/v1/chat/completions"
+
+ // 构建上游请求
+ upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
+ if err != nil {
+ s.writeSoraError(c, http.StatusInternalServerError, "api_error", "Failed to create upstream request", clientStream)
+ return nil, fmt.Errorf("create upstream request: %w", err)
+ }
+
+ upstreamReq.Header.Set("Content-Type", "application/json")
+ upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
+
+ // 透传客户端的部分请求头
+ for _, header := range []string{"Accept", "Accept-Encoding"} {
+ if v := c.GetHeader(header); v != "" {
+ upstreamReq.Header.Set(header, v)
+ }
+ }
+
+ logger.LegacyPrintf("service.sora", "[ForwardUpstream] account=%d url=%s", account.ID, upstreamURL)
+
+ // 获取代理 URL
+ proxyURL := ""
+ if account.ProxyID != nil && account.Proxy != nil {
+ proxyURL = account.Proxy.URL()
+ }
+
+ // 发送请求
+ resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
+ if err != nil {
+ s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Failed to connect to upstream Sora service", clientStream)
+ return nil, &UpstreamFailoverError{
+ StatusCode: http.StatusBadGateway,
+ }
+ }
+ defer func() {
+ _ = resp.Body.Close()
+ }()
+
+ // 错误响应处理
+ if resp.StatusCode >= 400 {
+ respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
+
+ if s.shouldFailoverUpstreamError(resp.StatusCode) {
+ return nil, &UpstreamFailoverError{
+ StatusCode: resp.StatusCode,
+ ResponseBody: respBody,
+ ResponseHeaders: resp.Header.Clone(),
+ }
+ }
+
+ // 非转移错误,直接透传给客户端
+ c.Status(resp.StatusCode)
+ for key, values := range resp.Header {
+ for _, v := range values {
+ c.Writer.Header().Add(key, v)
+ }
+ }
+ if _, err := c.Writer.Write(respBody); err != nil {
+ return nil, fmt.Errorf("write upstream error response: %w", err)
+ }
+ return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
+ }
+
+ // 成功响应 — 直接透传
+ c.Status(resp.StatusCode)
+ for key, values := range resp.Header {
+ lower := strings.ToLower(key)
+ // 透传内容相关头部
+ if lower == "content-type" || lower == "transfer-encoding" ||
+ lower == "cache-control" || lower == "x-request-id" {
+ for _, v := range values {
+ c.Writer.Header().Add(key, v)
+ }
+ }
+ }
+
+ // 流式复制响应体
+ if flusher, ok := c.Writer.(http.Flusher); ok && clientStream {
+ buf := make([]byte, 4096)
+ for {
+ n, readErr := resp.Body.Read(buf)
+ if n > 0 {
+ if _, err := c.Writer.Write(buf[:n]); err != nil {
+ return nil, fmt.Errorf("stream upstream response write: %w", err)
+ }
+ flusher.Flush()
+ }
+ if readErr != nil {
+ break
+ }
+ }
+ } else {
+ if _, err := io.Copy(c.Writer, resp.Body); err != nil {
+ return nil, fmt.Errorf("copy upstream response: %w", err)
+ }
+ }
+
+ duration := time.Since(startTime)
+ return &ForwardResult{
+ RequestID: resp.Header.Get("x-request-id"),
+ Model: "", // 由调用方填充
+ Stream: clientStream,
+ Duration: duration,
+ }, nil
+}
diff --git a/backend/internal/service/usage_cleanup.go b/backend/internal/service/usage_cleanup.go
index 7e3ffbb9..6e32f3c0 100644
--- a/backend/internal/service/usage_cleanup.go
+++ b/backend/internal/service/usage_cleanup.go
@@ -33,6 +33,7 @@ type UsageCleanupFilters struct {
AccountID *int64 `json:"account_id,omitempty"`
GroupID *int64 `json:"group_id,omitempty"`
Model *string `json:"model,omitempty"`
+ RequestType *int16 `json:"request_type,omitempty"`
Stream *bool `json:"stream,omitempty"`
BillingType *int8 `json:"billing_type,omitempty"`
}
diff --git a/backend/internal/service/usage_cleanup_service.go b/backend/internal/service/usage_cleanup_service.go
index ee795aa4..5600542e 100644
--- a/backend/internal/service/usage_cleanup_service.go
+++ b/backend/internal/service/usage_cleanup_service.go
@@ -68,6 +68,9 @@ func describeUsageCleanupFilters(filters UsageCleanupFilters) string {
if filters.Model != nil {
parts = append(parts, "model="+strings.TrimSpace(*filters.Model))
}
+ if filters.RequestType != nil {
+ parts = append(parts, "request_type="+RequestTypeFromInt16(*filters.RequestType).String())
+ }
if filters.Stream != nil {
parts = append(parts, fmt.Sprintf("stream=%t", *filters.Stream))
}
@@ -368,6 +371,16 @@ func sanitizeUsageCleanupFilters(filters *UsageCleanupFilters) {
filters.Model = &model
}
}
+ if filters.RequestType != nil {
+ requestType := RequestType(*filters.RequestType)
+ if !requestType.IsValid() {
+ filters.RequestType = nil
+ } else {
+ value := int16(requestType.Normalize())
+ filters.RequestType = &value
+ filters.Stream = nil
+ }
+ }
if filters.BillingType != nil && *filters.BillingType < 0 {
filters.BillingType = nil
}
diff --git a/backend/internal/service/usage_cleanup_service_test.go b/backend/internal/service/usage_cleanup_service_test.go
index 1f9f4776..0fdbfd47 100644
--- a/backend/internal/service/usage_cleanup_service_test.go
+++ b/backend/internal/service/usage_cleanup_service_test.go
@@ -257,6 +257,53 @@ func TestUsageCleanupServiceCreateTaskSanitizeFilters(t *testing.T) {
require.Equal(t, int64(9), task.CreatedBy)
}
+func TestSanitizeUsageCleanupFiltersRequestTypePriority(t *testing.T) {
+ requestType := int16(RequestTypeWSV2)
+ stream := false
+ model := " gpt-5 "
+ filters := UsageCleanupFilters{
+ Model: &model,
+ RequestType: &requestType,
+ Stream: &stream,
+ }
+
+ sanitizeUsageCleanupFilters(&filters)
+
+ require.NotNil(t, filters.RequestType)
+ require.Equal(t, int16(RequestTypeWSV2), *filters.RequestType)
+ require.Nil(t, filters.Stream)
+ require.NotNil(t, filters.Model)
+ require.Equal(t, "gpt-5", *filters.Model)
+}
+
+func TestSanitizeUsageCleanupFiltersInvalidRequestType(t *testing.T) {
+ requestType := int16(99)
+ stream := true
+ filters := UsageCleanupFilters{
+ RequestType: &requestType,
+ Stream: &stream,
+ }
+
+ sanitizeUsageCleanupFilters(&filters)
+
+ require.Nil(t, filters.RequestType)
+ require.NotNil(t, filters.Stream)
+ require.True(t, *filters.Stream)
+}
+
+func TestDescribeUsageCleanupFiltersIncludesRequestType(t *testing.T) {
+ start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ end := start.Add(24 * time.Hour)
+ requestType := int16(RequestTypeWSV2)
+ desc := describeUsageCleanupFilters(UsageCleanupFilters{
+ StartTime: start,
+ EndTime: end,
+ RequestType: &requestType,
+ })
+
+ require.Contains(t, desc, "request_type=ws_v2")
+}
+
func TestUsageCleanupServiceCreateTaskInvalidCreator(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go
index f9824183..c1a95541 100644
--- a/backend/internal/service/usage_log.go
+++ b/backend/internal/service/usage_log.go
@@ -1,12 +1,96 @@
package service
-import "time"
+import (
+ "fmt"
+ "strings"
+ "time"
+)
const (
BillingTypeBalance int8 = 0 // 钱包余额
BillingTypeSubscription int8 = 1 // 订阅套餐
)
+type RequestType int16
+
+const (
+ RequestTypeUnknown RequestType = 0
+ RequestTypeSync RequestType = 1
+ RequestTypeStream RequestType = 2
+ RequestTypeWSV2 RequestType = 3
+)
+
+func (t RequestType) IsValid() bool {
+ switch t {
+ case RequestTypeUnknown, RequestTypeSync, RequestTypeStream, RequestTypeWSV2:
+ return true
+ default:
+ return false
+ }
+}
+
+func (t RequestType) Normalize() RequestType {
+ if t.IsValid() {
+ return t
+ }
+ return RequestTypeUnknown
+}
+
+func (t RequestType) String() string {
+ switch t.Normalize() {
+ case RequestTypeSync:
+ return "sync"
+ case RequestTypeStream:
+ return "stream"
+ case RequestTypeWSV2:
+ return "ws_v2"
+ default:
+ return "unknown"
+ }
+}
+
+func RequestTypeFromInt16(v int16) RequestType {
+ return RequestType(v).Normalize()
+}
+
+func ParseUsageRequestType(value string) (RequestType, error) {
+ switch strings.ToLower(strings.TrimSpace(value)) {
+ case "unknown":
+ return RequestTypeUnknown, nil
+ case "sync":
+ return RequestTypeSync, nil
+ case "stream":
+ return RequestTypeStream, nil
+ case "ws_v2":
+ return RequestTypeWSV2, nil
+ default:
+ return RequestTypeUnknown, fmt.Errorf("invalid request_type, allowed values: unknown, sync, stream, ws_v2")
+ }
+}
+
+func RequestTypeFromLegacy(stream bool, openAIWSMode bool) RequestType {
+ if openAIWSMode {
+ return RequestTypeWSV2
+ }
+ if stream {
+ return RequestTypeStream
+ }
+ return RequestTypeSync
+}
+
+func ApplyLegacyRequestFields(requestType RequestType, fallbackStream bool, fallbackOpenAIWSMode bool) (stream bool, openAIWSMode bool) {
+ switch requestType.Normalize() {
+ case RequestTypeSync:
+ return false, false
+ case RequestTypeStream:
+ return true, false
+ case RequestTypeWSV2:
+ return true, true
+ default:
+ return fallbackStream, fallbackOpenAIWSMode
+ }
+}
+
type UsageLog struct {
ID int64
UserID int64
@@ -40,7 +124,9 @@ type UsageLog struct {
AccountRateMultiplier *float64
BillingType int8
+ RequestType RequestType
Stream bool
+ OpenAIWSMode bool
DurationMs *int
FirstTokenMs *int
UserAgent *string
@@ -66,3 +152,22 @@ type UsageLog struct {
func (u *UsageLog) TotalTokens() int {
return u.InputTokens + u.OutputTokens + u.CacheCreationTokens + u.CacheReadTokens
}
+
+func (u *UsageLog) EffectiveRequestType() RequestType {
+ if u == nil {
+ return RequestTypeUnknown
+ }
+ if normalized := u.RequestType.Normalize(); normalized != RequestTypeUnknown {
+ return normalized
+ }
+ return RequestTypeFromLegacy(u.Stream, u.OpenAIWSMode)
+}
+
+func (u *UsageLog) SyncRequestTypeAndLegacyFields() {
+ if u == nil {
+ return
+ }
+ requestType := u.EffectiveRequestType()
+ u.RequestType = requestType
+ u.Stream, u.OpenAIWSMode = ApplyLegacyRequestFields(requestType, u.Stream, u.OpenAIWSMode)
+}
diff --git a/backend/internal/service/usage_log_test.go b/backend/internal/service/usage_log_test.go
new file mode 100644
index 00000000..280237c2
--- /dev/null
+++ b/backend/internal/service/usage_log_test.go
@@ -0,0 +1,112 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestParseUsageRequestType(t *testing.T) {
+ t.Parallel()
+
+ type testCase struct {
+ name string
+ input string
+ want RequestType
+ wantErr bool
+ }
+
+ cases := []testCase{
+ {name: "unknown", input: "unknown", want: RequestTypeUnknown},
+ {name: "sync", input: "sync", want: RequestTypeSync},
+ {name: "stream", input: "stream", want: RequestTypeStream},
+ {name: "ws_v2", input: "ws_v2", want: RequestTypeWSV2},
+ {name: "case_insensitive", input: "WS_V2", want: RequestTypeWSV2},
+ {name: "trim_spaces", input: " stream ", want: RequestTypeStream},
+ {name: "invalid", input: "xxx", wantErr: true},
+ }
+
+ for _, tc := range cases {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ got, err := ParseUsageRequestType(tc.input)
+ if tc.wantErr {
+ require.Error(t, err)
+ return
+ }
+ require.NoError(t, err)
+ require.Equal(t, tc.want, got)
+ })
+ }
+}
+
+func TestRequestTypeNormalizeAndString(t *testing.T) {
+ t.Parallel()
+
+ require.Equal(t, RequestTypeUnknown, RequestType(99).Normalize())
+ require.Equal(t, "unknown", RequestType(99).String())
+ require.Equal(t, "sync", RequestTypeSync.String())
+ require.Equal(t, "stream", RequestTypeStream.String())
+ require.Equal(t, "ws_v2", RequestTypeWSV2.String())
+}
+
+func TestRequestTypeFromLegacy(t *testing.T) {
+ t.Parallel()
+
+ require.Equal(t, RequestTypeWSV2, RequestTypeFromLegacy(false, true))
+ require.Equal(t, RequestTypeStream, RequestTypeFromLegacy(true, false))
+ require.Equal(t, RequestTypeSync, RequestTypeFromLegacy(false, false))
+}
+
+func TestApplyLegacyRequestFields(t *testing.T) {
+ t.Parallel()
+
+ stream, ws := ApplyLegacyRequestFields(RequestTypeSync, true, true)
+ require.False(t, stream)
+ require.False(t, ws)
+
+ stream, ws = ApplyLegacyRequestFields(RequestTypeStream, false, true)
+ require.True(t, stream)
+ require.False(t, ws)
+
+ stream, ws = ApplyLegacyRequestFields(RequestTypeWSV2, false, false)
+ require.True(t, stream)
+ require.True(t, ws)
+
+ stream, ws = ApplyLegacyRequestFields(RequestTypeUnknown, true, false)
+ require.True(t, stream)
+ require.False(t, ws)
+}
+
+func TestUsageLogSyncRequestTypeAndLegacyFields(t *testing.T) {
+ t.Parallel()
+
+ log := &UsageLog{RequestType: RequestTypeWSV2, Stream: false, OpenAIWSMode: false}
+ log.SyncRequestTypeAndLegacyFields()
+
+ require.Equal(t, RequestTypeWSV2, log.RequestType)
+ require.True(t, log.Stream)
+ require.True(t, log.OpenAIWSMode)
+}
+
+func TestUsageLogEffectiveRequestTypeFallback(t *testing.T) {
+ t.Parallel()
+
+ log := &UsageLog{RequestType: RequestTypeUnknown, Stream: true, OpenAIWSMode: true}
+ require.Equal(t, RequestTypeWSV2, log.EffectiveRequestType())
+}
+
+func TestUsageLogEffectiveRequestTypeNilReceiver(t *testing.T) {
+ t.Parallel()
+
+ var log *UsageLog
+ require.Equal(t, RequestTypeUnknown, log.EffectiveRequestType())
+}
+
+func TestUsageLogSyncRequestTypeAndLegacyFieldsNilReceiver(t *testing.T) {
+ t.Parallel()
+
+ var log *UsageLog
+ log.SyncRequestTypeAndLegacyFields()
+}
diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go
index e56d83bf..487f12da 100644
--- a/backend/internal/service/user.go
+++ b/backend/internal/service/user.go
@@ -25,6 +25,10 @@ type User struct {
// map[groupID]rateMultiplier
GroupRates map[int64]float64
+ // Sora 存储配额
+ SoraStorageQuotaBytes int64 // 用户级 Sora 存储配额(0 表示使用分组或系统默认值)
+ SoraStorageUsedBytes int64 // Sora 存储已用量
+
// TOTP 双因素认证字段
TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥
TotpEnabled bool // 是否启用 TOTP
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index 510e734e..b5553935 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -40,6 +40,8 @@ type UserRepository interface {
UpdateConcurrency(ctx context.Context, id int64, amount int) error
ExistsByEmail(ctx context.Context, email string) (bool, error)
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
+ // AddGroupToAllowedGroups 将指定分组增量添加到用户的 allowed_groups(幂等,冲突忽略)
+ AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error
// TOTP 双因素认证
UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error
diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go
index 0f355d70..5ba2b99e 100644
--- a/backend/internal/service/user_service_test.go
+++ b/backend/internal/service/user_service_test.go
@@ -21,12 +21,12 @@ type mockUserRepo struct {
updateBalanceFn func(ctx context.Context, id int64, amount float64) error
}
-func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
-func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil }
-func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil }
-func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil }
-func (m *mockUserRepo) Update(context.Context, *User) error { return nil }
-func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
+func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
+func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil }
+func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil }
+func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil }
+func (m *mockUserRepo) Update(context.Context, *User) error { return nil }
+func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
@@ -45,7 +45,8 @@ func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { re
func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
-func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
+func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
@@ -56,8 +57,8 @@ type mockAuthCacheInvalidator struct {
mu sync.Mutex
}
-func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByKey(context.Context, string) {}
-func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByGroupID(context.Context, int64) {}
+func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByKey(context.Context, string) {}
+func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByGroupID(context.Context, int64) {}
func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByUserID(_ context.Context, userID int64) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -73,9 +74,9 @@ type mockBillingCache struct {
mu sync.Mutex
}
-func (m *mockBillingCache) GetUserBalance(context.Context, int64) (float64, error) { return 0, nil }
-func (m *mockBillingCache) SetUserBalance(context.Context, int64, float64) error { return nil }
-func (m *mockBillingCache) DeductUserBalance(context.Context, int64, float64) error { return nil }
+func (m *mockBillingCache) GetUserBalance(context.Context, int64) (float64, error) { return 0, nil }
+func (m *mockBillingCache) SetUserBalance(context.Context, int64, float64) error { return nil }
+func (m *mockBillingCache) DeductUserBalance(context.Context, int64, float64) error { return nil }
func (m *mockBillingCache) InvalidateUserBalance(_ context.Context, userID int64) error {
m.invalidateCallCount.Add(1)
m.mu.Lock()
diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go
index f04acc00..b0eccb71 100644
--- a/backend/internal/service/wire.go
+++ b/backend/internal/service/wire.go
@@ -284,6 +284,13 @@ func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthC
return apiKeyService
}
+// ProvideSettingService wires SettingService with group reader for default subscription validation.
+func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, cfg *config.Config) *SettingService {
+ svc := NewSettingService(settingRepo, cfg)
+ svc.SetDefaultSubscriptionGroupReader(groupRepo)
+ return svc
+}
+
// ProviderSet is the Wire provider set for all services
var ProviderSet = wire.NewSet(
// Core services
@@ -326,7 +333,8 @@ var ProviderSet = wire.NewSet(
ProvideRateLimitService,
NewAccountUsageService,
NewAccountTestService,
- NewSettingService,
+ ProvideSettingService,
+ NewDataManagementService,
ProvideOpsSystemLogSink,
NewOpsService,
ProvideOpsMetricsCollector,
@@ -338,6 +346,7 @@ var ProviderSet = wire.NewSet(
ProvideEmailQueueService,
NewTurnstileService,
NewSubscriptionService,
+ wire.Bind(new(DefaultSubscriptionAssigner), new(*SubscriptionService)),
ProvideConcurrencyService,
NewUsageRecordWorkerPool,
ProvideSchedulerSnapshotService,
diff --git a/backend/internal/testutil/stubs.go b/backend/internal/testutil/stubs.go
index 3569db17..217a5f56 100644
--- a/backend/internal/testutil/stubs.go
+++ b/backend/internal/testutil/stubs.go
@@ -66,6 +66,13 @@ func (c StubConcurrencyCache) GetUsersLoadBatch(_ context.Context, users []servi
}
return result, nil
}
+func (c StubConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) {
+ result := make(map[int64]int, len(accountIDs))
+ for _, id := range accountIDs {
+ result[id] = 0
+ }
+ return result, nil
+}
func (c StubConcurrencyCache) CleanupExpiredAccountSlots(_ context.Context, _ int64) error {
return nil
}
diff --git a/backend/internal/util/logredact/redact.go b/backend/internal/util/logredact/redact.go
index 492d875c..9249b761 100644
--- a/backend/internal/util/logredact/redact.go
+++ b/backend/internal/util/logredact/redact.go
@@ -3,7 +3,9 @@ package logredact
import (
"encoding/json"
"regexp"
+ "sort"
"strings"
+ "sync"
)
// maxRedactDepth 限制递归深度以防止栈溢出
@@ -31,9 +33,18 @@ var defaultSensitiveKeyList = []string{
"password",
}
+type textRedactPatterns struct {
+ reJSONLike *regexp.Regexp
+ reQueryLike *regexp.Regexp
+ rePlain *regexp.Regexp
+}
+
var (
reGOCSPX = regexp.MustCompile(`GOCSPX-[0-9A-Za-z_-]{24,}`)
reAIza = regexp.MustCompile(`AIza[0-9A-Za-z_-]{35}`)
+
+ defaultTextRedactPatterns = compileTextRedactPatterns(nil)
+ extraTextPatternCache sync.Map // map[string]*textRedactPatterns
)
func RedactMap(input map[string]any, extraKeys ...string) map[string]any {
@@ -83,23 +94,71 @@ func RedactText(input string, extraKeys ...string) string {
return RedactJSON(raw, extraKeys...)
}
- keyAlt := buildKeyAlternation(extraKeys)
- // JSON-like: "access_token":"..."
- reJSONLike := regexp.MustCompile(`(?i)("(?:` + keyAlt + `)"\s*:\s*")([^"]*)(")`)
- // Query-like: access_token=...
- reQueryLike := regexp.MustCompile(`(?i)\b((?:` + keyAlt + `))=([^&\s]+)`)
- // Plain: access_token: ... / access_token = ...
- rePlain := regexp.MustCompile(`(?i)\b((?:` + keyAlt + `))\b(\s*[:=]\s*)([^,\s]+)`)
+ patterns := getTextRedactPatterns(extraKeys)
out := input
out = reGOCSPX.ReplaceAllString(out, "GOCSPX-***")
out = reAIza.ReplaceAllString(out, "AIza***")
- out = reJSONLike.ReplaceAllString(out, `$1***$3`)
- out = reQueryLike.ReplaceAllString(out, `$1=***`)
- out = rePlain.ReplaceAllString(out, `$1$2***`)
+ out = patterns.reJSONLike.ReplaceAllString(out, `$1***$3`)
+ out = patterns.reQueryLike.ReplaceAllString(out, `$1=***`)
+ out = patterns.rePlain.ReplaceAllString(out, `$1$2***`)
return out
}
+func compileTextRedactPatterns(extraKeys []string) *textRedactPatterns {
+ keyAlt := buildKeyAlternation(extraKeys)
+ return &textRedactPatterns{
+ // JSON-like: "access_token":"..."
+ reJSONLike: regexp.MustCompile(`(?i)("(?:` + keyAlt + `)"\s*:\s*")([^"]*)(")`),
+ // Query-like: access_token=...
+ reQueryLike: regexp.MustCompile(`(?i)\b((?:` + keyAlt + `))=([^&\s]+)`),
+ // Plain: access_token: ... / access_token = ...
+ rePlain: regexp.MustCompile(`(?i)\b((?:` + keyAlt + `))\b(\s*[:=]\s*)([^,\s]+)`),
+ }
+}
+
+func getTextRedactPatterns(extraKeys []string) *textRedactPatterns {
+ normalizedExtraKeys := normalizeAndSortExtraKeys(extraKeys)
+ if len(normalizedExtraKeys) == 0 {
+ return defaultTextRedactPatterns
+ }
+
+ cacheKey := strings.Join(normalizedExtraKeys, ",")
+ if cached, ok := extraTextPatternCache.Load(cacheKey); ok {
+ if patterns, ok := cached.(*textRedactPatterns); ok {
+ return patterns
+ }
+ }
+
+ compiled := compileTextRedactPatterns(normalizedExtraKeys)
+ actual, _ := extraTextPatternCache.LoadOrStore(cacheKey, compiled)
+ if patterns, ok := actual.(*textRedactPatterns); ok {
+ return patterns
+ }
+ return compiled
+}
+
+func normalizeAndSortExtraKeys(extraKeys []string) []string {
+ if len(extraKeys) == 0 {
+ return nil
+ }
+ seen := make(map[string]struct{}, len(extraKeys))
+ keys := make([]string, 0, len(extraKeys))
+ for _, key := range extraKeys {
+ normalized := normalizeKey(key)
+ if normalized == "" {
+ continue
+ }
+ if _, ok := seen[normalized]; ok {
+ continue
+ }
+ seen[normalized] = struct{}{}
+ keys = append(keys, normalized)
+ }
+ sort.Strings(keys)
+ return keys
+}
+
func buildKeyAlternation(extraKeys []string) string {
seen := make(map[string]struct{}, len(defaultSensitiveKeyList)+len(extraKeys))
keys := make([]string, 0, len(defaultSensitiveKeyList)+len(extraKeys))
diff --git a/backend/internal/util/logredact/redact_test.go b/backend/internal/util/logredact/redact_test.go
index 64a7b3cf..266db69d 100644
--- a/backend/internal/util/logredact/redact_test.go
+++ b/backend/internal/util/logredact/redact_test.go
@@ -37,3 +37,48 @@ func TestRedactText_GOCSPX(t *testing.T) {
t.Fatalf("expected key redacted, got %q", out)
}
}
+
+func TestRedactText_ExtraKeyCacheUsesNormalizedSortedKey(t *testing.T) {
+ clearExtraTextPatternCache()
+
+ out1 := RedactText("custom_secret=abc", "Custom_Secret", " custom_secret ")
+ out2 := RedactText("custom_secret=xyz", "custom_secret")
+ if !strings.Contains(out1, "custom_secret=***") {
+ t.Fatalf("expected custom key redacted in first call, got %q", out1)
+ }
+ if !strings.Contains(out2, "custom_secret=***") {
+ t.Fatalf("expected custom key redacted in second call, got %q", out2)
+ }
+
+ if got := countExtraTextPatternCacheEntries(); got != 1 {
+ t.Fatalf("expected 1 cached pattern set, got %d", got)
+ }
+}
+
+func TestRedactText_DefaultPathDoesNotUseExtraCache(t *testing.T) {
+ clearExtraTextPatternCache()
+
+ out := RedactText("access_token=abc")
+ if !strings.Contains(out, "access_token=***") {
+ t.Fatalf("expected default key redacted, got %q", out)
+ }
+ if got := countExtraTextPatternCacheEntries(); got != 0 {
+ t.Fatalf("expected extra cache to remain empty, got %d", got)
+ }
+}
+
+func clearExtraTextPatternCache() {
+ extraTextPatternCache.Range(func(key, value any) bool {
+ extraTextPatternCache.Delete(key)
+ return true
+ })
+}
+
+func countExtraTextPatternCacheEntries() int {
+ count := 0
+ extraTextPatternCache.Range(func(key, value any) bool {
+ count++
+ return true
+ })
+ return count
+}
diff --git a/backend/internal/util/responseheaders/responseheaders.go b/backend/internal/util/responseheaders/responseheaders.go
index 86c3f624..7f7baca6 100644
--- a/backend/internal/util/responseheaders/responseheaders.go
+++ b/backend/internal/util/responseheaders/responseheaders.go
@@ -41,7 +41,14 @@ var hopByHopHeaders = map[string]struct{}{
"connection": {},
}
-func FilterHeaders(src http.Header, cfg config.ResponseHeaderConfig) http.Header {
+type CompiledHeaderFilter struct {
+ allowed map[string]struct{}
+ forceRemove map[string]struct{}
+}
+
+var defaultCompiledHeaderFilter = CompileHeaderFilter(config.ResponseHeaderConfig{})
+
+func CompileHeaderFilter(cfg config.ResponseHeaderConfig) *CompiledHeaderFilter {
allowed := make(map[string]struct{}, len(defaultAllowed)+len(cfg.AdditionalAllowed))
for key := range defaultAllowed {
allowed[key] = struct{}{}
@@ -69,13 +76,24 @@ func FilterHeaders(src http.Header, cfg config.ResponseHeaderConfig) http.Header
}
}
+ return &CompiledHeaderFilter{
+ allowed: allowed,
+ forceRemove: forceRemove,
+ }
+}
+
+func FilterHeaders(src http.Header, filter *CompiledHeaderFilter) http.Header {
+ if filter == nil {
+ filter = defaultCompiledHeaderFilter
+ }
+
filtered := make(http.Header, len(src))
for key, values := range src {
lower := strings.ToLower(key)
- if _, blocked := forceRemove[lower]; blocked {
+ if _, blocked := filter.forceRemove[lower]; blocked {
continue
}
- if _, ok := allowed[lower]; !ok {
+ if _, ok := filter.allowed[lower]; !ok {
continue
}
// 跳过 hop-by-hop 头部,这些由 HTTP 库自动处理
@@ -89,8 +107,8 @@ func FilterHeaders(src http.Header, cfg config.ResponseHeaderConfig) http.Header
return filtered
}
-func WriteFilteredHeaders(dst http.Header, src http.Header, cfg config.ResponseHeaderConfig) {
- filtered := FilterHeaders(src, cfg)
+func WriteFilteredHeaders(dst http.Header, src http.Header, filter *CompiledHeaderFilter) {
+ filtered := FilterHeaders(src, filter)
for key, values := range filtered {
for _, value := range values {
dst.Add(key, value)
diff --git a/backend/internal/util/responseheaders/responseheaders_test.go b/backend/internal/util/responseheaders/responseheaders_test.go
index f7343267..d817559e 100644
--- a/backend/internal/util/responseheaders/responseheaders_test.go
+++ b/backend/internal/util/responseheaders/responseheaders_test.go
@@ -20,7 +20,7 @@ func TestFilterHeadersDisabledUsesDefaultAllowlist(t *testing.T) {
ForceRemove: []string{"x-request-id"},
}
- filtered := FilterHeaders(src, cfg)
+ filtered := FilterHeaders(src, CompileHeaderFilter(cfg))
if filtered.Get("Content-Type") != "application/json" {
t.Fatalf("expected Content-Type passthrough, got %q", filtered.Get("Content-Type"))
}
@@ -51,7 +51,7 @@ func TestFilterHeadersEnabledUsesAllowlist(t *testing.T) {
ForceRemove: []string{"x-remove"},
}
- filtered := FilterHeaders(src, cfg)
+ filtered := FilterHeaders(src, CompileHeaderFilter(cfg))
if filtered.Get("Content-Type") != "application/json" {
t.Fatalf("expected Content-Type allowed, got %q", filtered.Get("Content-Type"))
}
diff --git a/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql b/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql
new file mode 100644
index 00000000..d0ed5d6d
--- /dev/null
+++ b/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql
@@ -0,0 +1,46 @@
+-- Add gemini-3.1-flash-image and gemini-3.1-flash-image-preview to model_mapping
+--
+-- Background:
+-- Antigravity now supports gemini-3.1-flash-image as the latest image generation model,
+-- replacing the previous gemini-3-pro-image.
+--
+-- Strategy:
+-- Directly overwrite the entire model_mapping with updated mappings
+-- This ensures consistency with DefaultAntigravityModelMapping in constants.go
+
+UPDATE accounts
+SET credentials = jsonb_set(
+ credentials,
+ '{model_mapping}',
+ '{
+ "claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
+ "claude-opus-4-6": "claude-opus-4-6-thinking",
+ "claude-opus-4-5-thinking": "claude-opus-4-6-thinking",
+ "claude-opus-4-5-20251101": "claude-opus-4-6-thinking",
+ "claude-sonnet-4-6": "claude-sonnet-4-6",
+ "claude-sonnet-4-5": "claude-sonnet-4-5",
+ "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
+ "claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
+ "claude-haiku-4-5": "claude-sonnet-4-5",
+ "claude-haiku-4-5-20251001": "claude-sonnet-4-5",
+ "gemini-2.5-flash": "gemini-2.5-flash",
+ "gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
+ "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
+ "gemini-2.5-pro": "gemini-2.5-pro",
+ "gemini-3-flash": "gemini-3-flash",
+ "gemini-3-pro-high": "gemini-3-pro-high",
+ "gemini-3-pro-low": "gemini-3-pro-low",
+ "gemini-3-flash-preview": "gemini-3-flash",
+ "gemini-3-pro-preview": "gemini-3-pro-high",
+ "gemini-3.1-pro-high": "gemini-3.1-pro-high",
+ "gemini-3.1-pro-low": "gemini-3.1-pro-low",
+ "gemini-3.1-pro-preview": "gemini-3.1-pro-high",
+ "gemini-3.1-flash-image": "gemini-3.1-flash-image",
+ "gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
+ "gpt-oss-120b-medium": "gpt-oss-120b-medium",
+ "tab_flash_lite_preview": "tab_flash_lite_preview"
+ }'::jsonb
+)
+WHERE platform = 'antigravity'
+ AND deleted_at IS NULL
+ AND credentials->'model_mapping' IS NOT NULL;
\ No newline at end of file
diff --git a/backend/migrations/060_add_usage_log_openai_ws_mode.sql b/backend/migrations/060_add_usage_log_openai_ws_mode.sql
new file mode 100644
index 00000000..b7d22414
--- /dev/null
+++ b/backend/migrations/060_add_usage_log_openai_ws_mode.sql
@@ -0,0 +1,2 @@
+-- Add openai_ws_mode flag to usage_logs to persist exact OpenAI WS transport type.
+ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS openai_ws_mode BOOLEAN NOT NULL DEFAULT FALSE;
diff --git a/backend/migrations/061_add_usage_log_request_type.sql b/backend/migrations/061_add_usage_log_request_type.sql
new file mode 100644
index 00000000..68a33d51
--- /dev/null
+++ b/backend/migrations/061_add_usage_log_request_type.sql
@@ -0,0 +1,29 @@
+-- Add request_type enum for usage_logs while keeping legacy stream/openai_ws_mode compatibility.
+ALTER TABLE usage_logs
+ ADD COLUMN IF NOT EXISTS request_type SMALLINT NOT NULL DEFAULT 0;
+
+DO $$
+BEGIN
+ IF NOT EXISTS (
+ SELECT 1
+ FROM pg_constraint
+ WHERE conname = 'usage_logs_request_type_check'
+ ) THEN
+ ALTER TABLE usage_logs
+ ADD CONSTRAINT usage_logs_request_type_check
+ CHECK (request_type IN (0, 1, 2, 3));
+ END IF;
+END
+$$;
+
+CREATE INDEX IF NOT EXISTS idx_usage_logs_request_type_created_at
+ ON usage_logs (request_type, created_at);
+
+-- Backfill from legacy fields. openai_ws_mode has higher priority than stream.
+UPDATE usage_logs
+SET request_type = CASE
+ WHEN openai_ws_mode = TRUE THEN 3
+ WHEN stream = TRUE THEN 2
+ ELSE 1
+END
+WHERE request_type = 0;
diff --git a/backend/migrations/062_add_scheduler_and_usage_composite_indexes_notx.sql b/backend/migrations/062_add_scheduler_and_usage_composite_indexes_notx.sql
new file mode 100644
index 00000000..c6139338
--- /dev/null
+++ b/backend/migrations/062_add_scheduler_and_usage_composite_indexes_notx.sql
@@ -0,0 +1,15 @@
+CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_accounts_schedulable_hot
+ ON accounts (platform, priority)
+ WHERE deleted_at IS NULL AND status = 'active' AND schedulable = true;
+
+CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_accounts_active_schedulable
+ ON accounts (priority, status)
+ WHERE deleted_at IS NULL AND schedulable = true;
+
+CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_user_subscriptions_user_status_expires_active
+ ON user_subscriptions (user_id, status, expires_at)
+ WHERE deleted_at IS NULL;
+
+CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_logs_group_created_at_not_null
+ ON usage_logs (group_id, created_at)
+ WHERE group_id IS NOT NULL;
diff --git a/backend/migrations/063_add_sora_client_tables.sql b/backend/migrations/063_add_sora_client_tables.sql
new file mode 100644
index 00000000..69197f10
--- /dev/null
+++ b/backend/migrations/063_add_sora_client_tables.sql
@@ -0,0 +1,56 @@
+-- Migration: 063_add_sora_client_tables
+-- Sora 客户端功能所需的数据库变更:
+-- 1. 新增 sora_generations 表:记录 Sora 客户端 UI 的生成历史
+-- 2. users 表新增存储配额字段
+-- 3. groups 表新增存储配额字段
+
+-- ============================================================
+-- 1. sora_generations 表(生成记录)
+-- ============================================================
+CREATE TABLE IF NOT EXISTS sora_generations (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+ api_key_id BIGINT,
+
+ -- 生成参数
+ model VARCHAR(64) NOT NULL,
+ prompt TEXT NOT NULL DEFAULT '',
+ media_type VARCHAR(16) NOT NULL DEFAULT 'video', -- video / image
+
+ -- 结果
+ status VARCHAR(16) NOT NULL DEFAULT 'pending', -- pending / generating / completed / failed / cancelled
+ media_url TEXT NOT NULL DEFAULT '',
+ media_urls JSONB, -- 多图时的 URL 数组
+ file_size_bytes BIGINT NOT NULL DEFAULT 0,
+ storage_type VARCHAR(16) NOT NULL DEFAULT 'none', -- s3 / local / upstream / none
+ s3_object_keys JSONB, -- S3 object key 数组
+
+ -- 上游信息
+ upstream_task_id VARCHAR(128) NOT NULL DEFAULT '',
+ error_message TEXT NOT NULL DEFAULT '',
+
+ -- 时间
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ completed_at TIMESTAMPTZ
+);
+
+-- 按用户+时间查询(作品库列表、历史记录)
+CREATE INDEX IF NOT EXISTS idx_sora_gen_user_created
+ ON sora_generations(user_id, created_at DESC);
+
+-- 按用户+状态查询(恢复进行中任务)
+CREATE INDEX IF NOT EXISTS idx_sora_gen_user_status
+ ON sora_generations(user_id, status);
+
+-- ============================================================
+-- 2. users 表新增 Sora 存储配额字段
+-- ============================================================
+ALTER TABLE users
+ ADD COLUMN IF NOT EXISTS sora_storage_quota_bytes BIGINT NOT NULL DEFAULT 0,
+ ADD COLUMN IF NOT EXISTS sora_storage_used_bytes BIGINT NOT NULL DEFAULT 0;
+
+-- ============================================================
+-- 3. groups 表新增 Sora 存储配额字段
+-- ============================================================
+ALTER TABLE groups
+ ADD COLUMN IF NOT EXISTS sora_storage_quota_bytes BIGINT NOT NULL DEFAULT 0;
diff --git a/backend/migrations/README.md b/backend/migrations/README.md
index 3fe328e6..47f6fa35 100644
--- a/backend/migrations/README.md
+++ b/backend/migrations/README.md
@@ -12,6 +12,26 @@ Format: `NNN_description.sql`
Example: `017_add_gemini_tier_id.sql`
+### `_notx.sql` 命名与执行语义(并发索引专用)
+
+当迁移包含 `CREATE INDEX CONCURRENTLY` 或 `DROP INDEX CONCURRENTLY` 时,必须使用 `_notx.sql` 后缀,例如:
+
+- `062_add_accounts_priority_indexes_notx.sql`
+- `063_drop_legacy_indexes_notx.sql`
+
+运行规则:
+
+1. `*.sql`(不带 `_notx`)按事务执行。
+2. `*_notx.sql` 按非事务执行,不会包裹在 `BEGIN/COMMIT` 中。
+3. `*_notx.sql` 仅允许并发索引语句,不允许混入事务控制语句或其他 DDL/DML。
+
+幂等要求(必须):
+
+- 创建索引:`CREATE INDEX CONCURRENTLY IF NOT EXISTS ...`
+- 删除索引:`DROP INDEX CONCURRENTLY IF EXISTS ...`
+
+这样可以保证灾备重放、重复执行时不会因对象已存在/不存在而失败。
+
## Migration File Structure
```sql
diff --git a/deploy/.env.example b/deploy/.env.example
index 290f918a..9f2ff13e 100644
--- a/deploy/.env.example
+++ b/deploy/.env.example
@@ -66,11 +66,15 @@ LOG_SAMPLING_INITIAL=100
# 之后每 N 条保留 1 条
LOG_SAMPLING_THEREAFTER=100
-# Global max request body size in bytes (default: 100MB)
-# 全局最大请求体大小(字节,默认 100MB)
+# Global max request body size in bytes (default: 256MB)
+# 全局最大请求体大小(字节,默认 256MB)
# Applies to all requests, especially important for h2c first request memory protection
# 适用于所有请求,对 h2c 第一请求的内存保护尤为重要
-SERVER_MAX_REQUEST_BODY_SIZE=104857600
+SERVER_MAX_REQUEST_BODY_SIZE=268435456
+
+# Gateway max request body size in bytes (default: 256MB)
+# 网关请求体最大字节数(默认 256MB)
+GATEWAY_MAX_BODY_SIZE=268435456
# Enable HTTP/2 Cleartext (h2c) for client connections
# 启用 HTTP/2 Cleartext (h2c) 客户端连接
diff --git a/deploy/DATAMANAGEMENTD_CN.md b/deploy/DATAMANAGEMENTD_CN.md
new file mode 100644
index 00000000..774f03ae
--- /dev/null
+++ b/deploy/DATAMANAGEMENTD_CN.md
@@ -0,0 +1,78 @@
+# datamanagementd 部署说明(数据管理)
+
+本文说明如何在宿主机部署 `datamanagementd`,并与主进程联动开启“数据管理”功能。
+
+## 1. 关键约束
+
+- 主进程固定探测路径:`/tmp/sub2api-datamanagement.sock`
+- 仅当该 Unix Socket 可连通且 `Health` 成功时,后台“数据管理”才会启用
+- `datamanagementd` 使用 SQLite 持久化元数据,不依赖主库
+
+## 2. 宿主机构建与运行
+
+```bash
+cd /opt/sub2api-src/datamanagement
+go build -o /opt/sub2api/datamanagementd ./cmd/datamanagementd
+
+mkdir -p /var/lib/sub2api/datamanagement
+chown -R sub2api:sub2api /var/lib/sub2api/datamanagement
+```
+
+手动启动示例:
+
+```bash
+/opt/sub2api/datamanagementd \
+ -socket-path /tmp/sub2api-datamanagement.sock \
+ -sqlite-path /var/lib/sub2api/datamanagement/datamanagementd.db \
+ -version 1.0.0
+```
+
+## 3. systemd 托管(推荐)
+
+仓库已提供示例服务文件:`deploy/sub2api-datamanagementd.service`
+
+```bash
+sudo cp deploy/sub2api-datamanagementd.service /etc/systemd/system/
+sudo systemctl daemon-reload
+sudo systemctl enable --now sub2api-datamanagementd
+sudo systemctl status sub2api-datamanagementd
+```
+
+查看日志:
+
+```bash
+sudo journalctl -u sub2api-datamanagementd -f
+```
+
+也可以使用一键安装脚本(自动安装二进制 + 注册 systemd):
+
+```bash
+# 方式一:使用现成二进制
+sudo ./deploy/install-datamanagementd.sh --binary /path/to/datamanagementd
+
+# 方式二:从源码构建后安装
+sudo ./deploy/install-datamanagementd.sh --source /path/to/sub2api
+```
+
+## 4. Docker 部署联动
+
+若 `sub2api` 运行在 Docker 容器中,需要将宿主机 Socket 挂载到容器同路径:
+
+```yaml
+services:
+ sub2api:
+ volumes:
+ - /tmp/sub2api-datamanagement.sock:/tmp/sub2api-datamanagement.sock
+```
+
+建议在 `docker-compose.override.yml` 中维护该挂载,避免覆盖主 compose 文件。
+
+## 5. 依赖检查
+
+`datamanagementd` 执行备份时依赖以下工具:
+
+- `pg_dump`
+- `redis-cli`
+- `docker`(仅 `source_mode=docker_exec` 时)
+
+缺失依赖会导致对应任务失败,并在任务详情中体现错误信息。
diff --git a/deploy/README.md b/deploy/README.md
index 3292e81a..807bf510 100644
--- a/deploy/README.md
+++ b/deploy/README.md
@@ -19,7 +19,10 @@ This directory contains files for deploying Sub2API on Linux servers.
| `.env.example` | Docker environment variables template |
| `DOCKER.md` | Docker Hub documentation |
| `install.sh` | One-click binary installation script |
+| `install-datamanagementd.sh` | datamanagementd 一键安装脚本 |
| `sub2api.service` | Systemd service unit file |
+| `sub2api-datamanagementd.service` | datamanagementd systemd service unit file |
+| `DATAMANAGEMENTD_CN.md` | datamanagementd 部署与联动说明(中文) |
| `config.example.yaml` | Example configuration file |
---
@@ -145,6 +148,14 @@ SELECT
(SELECT COUNT(*) FROM user_allowed_groups) AS new_pair_count;
```
+### datamanagementd(数据管理)联动
+
+如需启用管理后台“数据管理”功能,请额外部署宿主机 `datamanagementd`:
+
+- 主进程固定探测 `/tmp/sub2api-datamanagement.sock`
+- Docker 场景下需把宿主机 Socket 挂载到容器内同路径
+- 详细步骤见:`deploy/DATAMANAGEMENTD_CN.md`
+
### Commands
For **local directory version** (docker-compose.local.yml):
@@ -575,7 +586,7 @@ gateway:
name: "Profile 2"
cipher_suites: [4866, 4867, 4865, 49199, 49195, 49200, 49196]
curves: [29, 23, 24]
- point_formats: [0]
+ point_formats: 0
# Another custom profile
profile_3:
diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml
index 46a91ad6..faa85854 100644
--- a/deploy/config.example.yaml
+++ b/deploy/config.example.yaml
@@ -27,11 +27,11 @@ server:
# Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies.
# 信任的代理地址(CIDR/IP 格式),用于解析 X-Forwarded-For 头。留空则禁用代理信任。
trusted_proxies: []
- # Global max request body size in bytes (default: 100MB)
- # 全局最大请求体大小(字节,默认 100MB)
+ # Global max request body size in bytes (default: 256MB)
+ # 全局最大请求体大小(字节,默认 256MB)
# Applies to all requests, especially important for h2c first request memory protection
# 适用于所有请求,对 h2c 第一请求的内存保护尤为重要
- max_request_body_size: 104857600
+ max_request_body_size: 268435456
# HTTP/2 Cleartext (h2c) configuration
# HTTP/2 Cleartext (h2c) 配置
h2c:
@@ -143,9 +143,9 @@ gateway:
# Timeout for waiting upstream response headers (seconds)
# 等待上游响应头超时时间(秒)
response_header_timeout: 600
- # Max request body size in bytes (default: 100MB)
- # 请求体最大字节数(默认 100MB)
- max_body_size: 104857600
+ # Max request body size in bytes (default: 256MB)
+ # 请求体最大字节数(默认 256MB)
+ max_body_size: 268435456
# Max bytes to read for non-stream upstream responses (default: 8MB)
# 非流式上游响应体读取上限(默认 8MB)
upstream_response_read_max_bytes: 8388608
@@ -199,6 +199,83 @@ gateway:
# OpenAI 透传模式是否放行客户端超时头(如 x-stainless-timeout)
# 默认 false:过滤超时头,降低上游提前断流风险。
openai_passthrough_allow_timeout_headers: false
+ # OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP)
+ openai_ws:
+ # 新版 WS mode 路由(默认关闭)。关闭时保持当前 legacy 实现行为。
+ mode_router_v2_enabled: false
+ # ingress 默认模式:off|shared|dedicated(仅 mode_router_v2_enabled=true 生效)
+ ingress_mode_default: shared
+ # 全局总开关,默认 true;关闭时所有请求保持原有 HTTP/SSE 路由
+ enabled: true
+ # 按账号类型细分开关
+ oauth_enabled: true
+ apikey_enabled: true
+ # 全局强制 HTTP(紧急回滚开关)
+ force_http: false
+ # 允许在 WSv2 下按策略恢复 store=true(默认 false)
+ allow_store_recovery: false
+ # ingress 模式收到 previous_response_not_found 时,自动去掉 previous_response_id 重试一次(默认 true)
+ ingress_previous_response_recovery_enabled: true
+ # store=false 且无可复用会话连接时的策略:
+ # strict=强制新建连接(隔离优先),adaptive=仅在高风险失败后强制新建,off=尽量复用(性能优先)
+ store_disabled_conn_mode: strict
+ # store=false 且无可复用会话连接时,是否强制新建连接(默认 true,优先会话隔离)
+ # 兼容旧配置:仅在 store_disabled_conn_mode 未配置时生效
+ store_disabled_force_new_conn: true
+ # 是否启用 WSv2 generate=false 预热(默认 false)
+ prewarm_generate_enabled: false
+ # 协议 feature 开关,v2 优先于 v1
+ responses_websockets: false
+ responses_websockets_v2: true
+ # 连接池参数(按账号池化复用)
+ max_conns_per_account: 128
+ min_idle_per_account: 4
+ max_idle_per_account: 12
+ # 是否按账号并发动态计算连接池上限:
+ # effective_max_conns = min(max_conns_per_account, ceil(account.concurrency * factor))
+ dynamic_max_conns_by_account_concurrency_enabled: true
+ # 按账号类型分别设置系数(OAuth / API Key)
+ oauth_max_conns_factor: 1.0
+ apikey_max_conns_factor: 1.0
+ dial_timeout_seconds: 10
+ read_timeout_seconds: 900
+ write_timeout_seconds: 120
+ pool_target_utilization: 0.7
+ queue_limit_per_conn: 64
+ # 流式写出批量 flush 参数
+ event_flush_batch_size: 1
+ event_flush_interval_ms: 10
+ # 预热触发冷却(毫秒)
+ prewarm_cooldown_ms: 300
+ # WS 回退到 HTTP 后的冷却时间(秒),用于避免 WS/HTTP 来回抖动;0 表示关闭冷却
+ fallback_cooldown_seconds: 30
+ # WS 重试退避参数(毫秒)
+ retry_backoff_initial_ms: 120
+ retry_backoff_max_ms: 2000
+ # 抖动比例(0-1)
+ retry_jitter_ratio: 0.2
+ # 单次请求 WS 重试总预算(毫秒);建议设置为有限值,避免重试拉高 TTFT 长尾
+ retry_total_budget_ms: 5000
+ # payload_schema 日志采样率(0-1);降低热路径日志放大
+ payload_log_sample_rate: 0.2
+ # 调度与粘连参数
+ lb_top_k: 7
+ sticky_session_ttl_seconds: 3600
+ # 会话哈希迁移兼容开关:新 key 未命中时回退读取旧 SHA-256 key
+ session_hash_read_old_fallback: true
+ # 会话哈希迁移兼容开关:写入时双写旧 SHA-256 key(短 TTL)
+ session_hash_dual_write_old: true
+ # context 元数据迁移兼容开关:保留旧 ctxkey.* 读取/注入桥接
+ metadata_bridge_enabled: true
+ sticky_response_id_ttl_seconds: 3600
+ # 兼容旧键:当 sticky_response_id_ttl_seconds 缺失时回退该值
+ sticky_previous_response_ttl_seconds: 3600
+ scheduler_score_weights:
+ priority: 1.0
+ load: 1.0
+ queue: 0.7
+ error_rate: 0.8
+ ttft: 0.5
# HTTP upstream connection pool settings (HTTP/2 + multi-proxy scenario defaults)
# HTTP 上游连接池配置(HTTP/2 + 多代理场景默认值)
# Max idle connections across all hosts
@@ -779,12 +856,12 @@ rate_limit:
# 定价数据源(可选)
# =============================================================================
pricing:
- # URL to fetch model pricing data (default: LiteLLM)
- # 获取模型定价数据的 URL(默认:LiteLLM)
- remote_url: "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.json"
+ # URL to fetch model pricing data (default: pinned model-price-repo commit)
+ # 获取模型定价数据的 URL(默认:固定 commit 的 model-price-repo)
+ remote_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json"
# Hash verification URL (optional)
# 哈希校验 URL(可选)
- hash_url: "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.sha256"
+ hash_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256"
# Local data directory for caching
# 本地数据缓存目录
data_dir: "./data"
diff --git a/deploy/docker-compose.override.yml.example b/deploy/docker-compose.override.yml.example
index 297724f5..7157f212 100644
--- a/deploy/docker-compose.override.yml.example
+++ b/deploy/docker-compose.override.yml.example
@@ -127,6 +127,19 @@ services:
# - ./logs:/app/logs
# - ./backups:/app/backups
+# =============================================================================
+# Scenario 6: 启用宿主机 datamanagementd(数据管理)
+# =============================================================================
+# 说明:
+# - datamanagementd 运行在宿主机(systemd 或手动)
+# - 主进程固定探测 /tmp/sub2api-datamanagement.sock
+# - 需要把宿主机 socket 挂载到容器内同路径
+#
+# services:
+# sub2api:
+# volumes:
+# - /tmp/sub2api-datamanagement.sock:/tmp/sub2api-datamanagement.sock
+
# =============================================================================
# Additional Notes
# =============================================================================
diff --git a/deploy/install-datamanagementd.sh b/deploy/install-datamanagementd.sh
new file mode 100755
index 00000000..8d53134b
--- /dev/null
+++ b/deploy/install-datamanagementd.sh
@@ -0,0 +1,123 @@
+#!/usr/bin/env bash
+
+set -euo pipefail
+
+# 用法:
+# sudo ./install-datamanagementd.sh --binary /path/to/datamanagementd
+# 或:
+# sudo ./install-datamanagementd.sh --source /path/to/sub2api/repo
+
+BIN_PATH=""
+SOURCE_PATH=""
+INSTALL_DIR="/opt/sub2api"
+DATA_DIR="/var/lib/sub2api/datamanagement"
+SERVICE_FILE_NAME="sub2api-datamanagementd.service"
+
+function print_help() {
+ cat <<'EOF'
+用法:
+ install-datamanagementd.sh [--binary ] [--source <仓库路径>]
+
+参数:
+ --binary 指定已构建的 datamanagementd 二进制路径
+ --source 指定 sub2api 仓库路径(脚本会执行 go build)
+ -h, --help 显示帮助
+
+示例:
+ sudo ./install-datamanagementd.sh --binary ./datamanagement/datamanagementd
+ sudo ./install-datamanagementd.sh --source /opt/sub2api-src
+EOF
+}
+
+while [[ $# -gt 0 ]]; do
+ case "$1" in
+ --binary)
+ BIN_PATH="${2:-}"
+ shift 2
+ ;;
+ --source)
+ SOURCE_PATH="${2:-}"
+ shift 2
+ ;;
+ -h|--help)
+ print_help
+ exit 0
+ ;;
+ *)
+ echo "未知参数: $1"
+ print_help
+ exit 1
+ ;;
+ esac
+done
+
+if [[ -n "$BIN_PATH" && -n "$SOURCE_PATH" ]]; then
+ echo "错误: --binary 与 --source 只能二选一"
+ exit 1
+fi
+
+if [[ -z "$BIN_PATH" && -z "$SOURCE_PATH" ]]; then
+ echo "错误: 必须提供 --binary 或 --source"
+ exit 1
+fi
+
+if [[ "$(id -u)" -ne 0 ]]; then
+ echo "错误: 请使用 root 权限执行(例如 sudo)"
+ exit 1
+fi
+
+if [[ -n "$SOURCE_PATH" ]]; then
+ if [[ ! -d "$SOURCE_PATH/datamanagement" ]]; then
+ echo "错误: 无效仓库路径,未找到 $SOURCE_PATH/datamanagement"
+ exit 1
+ fi
+ echo "[1/6] 从源码构建 datamanagementd..."
+ (cd "$SOURCE_PATH/datamanagement" && go build -o datamanagementd ./cmd/datamanagementd)
+ BIN_PATH="$SOURCE_PATH/datamanagement/datamanagementd"
+fi
+
+if [[ ! -f "$BIN_PATH" ]]; then
+ echo "错误: 二进制文件不存在: $BIN_PATH"
+ exit 1
+fi
+
+if ! id sub2api >/dev/null 2>&1; then
+ echo "[2/6] 创建系统用户 sub2api..."
+ useradd --system --no-create-home --shell /usr/sbin/nologin sub2api
+else
+ echo "[2/6] 系统用户 sub2api 已存在,跳过创建"
+fi
+
+echo "[3/6] 安装 datamanagementd 二进制..."
+mkdir -p "$INSTALL_DIR"
+install -m 0755 "$BIN_PATH" "$INSTALL_DIR/datamanagementd"
+
+echo "[4/6] 准备数据目录..."
+mkdir -p "$DATA_DIR"
+chown -R sub2api:sub2api /var/lib/sub2api
+chmod 0750 "$DATA_DIR"
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+SERVICE_TEMPLATE="$SCRIPT_DIR/$SERVICE_FILE_NAME"
+if [[ ! -f "$SERVICE_TEMPLATE" ]]; then
+ echo "错误: 未找到服务模板 $SERVICE_TEMPLATE"
+ exit 1
+fi
+
+echo "[5/6] 安装 systemd 服务..."
+cp "$SERVICE_TEMPLATE" "/etc/systemd/system/$SERVICE_FILE_NAME"
+systemctl daemon-reload
+systemctl enable --now sub2api-datamanagementd
+
+echo "[6/6] 完成,当前状态:"
+systemctl --no-pager --full status sub2api-datamanagementd || true
+
+cat <<'EOF'
+
+下一步建议:
+1. 查看日志:sudo journalctl -u sub2api-datamanagementd -f
+2. 在 sub2api(容器部署时)挂载 socket:
+ /tmp/sub2api-datamanagement.sock:/tmp/sub2api-datamanagement.sock
+3. 进入管理后台“数据管理”页面确认 agent=enabled
+
+EOF
diff --git a/deploy/sub2api-datamanagementd.service b/deploy/sub2api-datamanagementd.service
new file mode 100644
index 00000000..b32733b7
--- /dev/null
+++ b/deploy/sub2api-datamanagementd.service
@@ -0,0 +1,22 @@
+[Unit]
+Description=Sub2API Data Management Daemon
+After=network.target
+Wants=network.target
+
+[Service]
+Type=simple
+User=sub2api
+Group=sub2api
+WorkingDirectory=/opt/sub2api
+ExecStart=/opt/sub2api/datamanagementd \
+ -socket-path /tmp/sub2api-datamanagement.sock \
+ -sqlite-path /var/lib/sub2api/datamanagement/datamanagementd.db \
+ -version 1.0.0
+Restart=always
+RestartSec=5s
+LimitNOFILE=100000
+NoNewPrivileges=true
+PrivateTmp=false
+
+[Install]
+WantedBy=multi-user.target
diff --git a/docs/ADMIN_PAYMENT_INTEGRATION_API.md b/docs/ADMIN_PAYMENT_INTEGRATION_API.md
new file mode 100644
index 00000000..4cc21594
--- /dev/null
+++ b/docs/ADMIN_PAYMENT_INTEGRATION_API.md
@@ -0,0 +1,241 @@
+# ADMIN_PAYMENT_INTEGRATION_API
+
+> 单文件中英双语文档 / Single-file bilingual documentation (Chinese + English)
+
+---
+
+## 中文
+
+### 目标
+本文档用于对接外部支付系统(如 `sub2apipay`)与 Sub2API 的 Admin API,覆盖:
+- 支付成功后充值
+- 用户查询
+- 人工余额修正
+- 前端购买页参数透传
+
+### 基础地址
+- 生产:`https://`
+- Beta:`http://:8084`
+
+### 认证
+推荐使用:
+- `x-api-key: admin-<64hex>`
+- `Content-Type: application/json`
+- 幂等接口额外传:`Idempotency-Key`
+
+说明:管理员 JWT 也可访问 admin 路由,但服务间调用建议使用 Admin API Key。
+
+### 1) 一步完成创建并兑换
+`POST /api/v1/admin/redeem-codes/create-and-redeem`
+
+用途:原子完成“创建兑换码 + 兑换到指定用户”。
+
+请求头:
+- `x-api-key`
+- `Idempotency-Key`
+
+请求体示例:
+```json
+{
+ "code": "s2p_cm1234567890",
+ "type": "balance",
+ "value": 100.0,
+ "user_id": 123,
+ "notes": "sub2apipay order: cm1234567890"
+}
+```
+
+幂等语义:
+- 同 `code` 且 `used_by` 一致:`200`
+- 同 `code` 但 `used_by` 不一致:`409`
+- 缺少 `Idempotency-Key`:`400`(`IDEMPOTENCY_KEY_REQUIRED`)
+
+curl 示例:
+```bash
+curl -X POST "${BASE}/api/v1/admin/redeem-codes/create-and-redeem" \
+ -H "x-api-key: ${KEY}" \
+ -H "Idempotency-Key: pay-cm1234567890-success" \
+ -H "Content-Type: application/json" \
+ -d '{
+ "code":"s2p_cm1234567890",
+ "type":"balance",
+ "value":100.00,
+ "user_id":123,
+ "notes":"sub2apipay order: cm1234567890"
+ }'
+```
+
+### 2) 查询用户(可选前置校验)
+`GET /api/v1/admin/users/:id`
+
+```bash
+curl -s "${BASE}/api/v1/admin/users/123" \
+ -H "x-api-key: ${KEY}"
+```
+
+### 3) 余额调整(已有接口)
+`POST /api/v1/admin/users/:id/balance`
+
+用途:人工补偿 / 扣减,支持 `set` / `add` / `subtract`。
+
+请求体示例(扣减):
+```json
+{
+ "balance": 100.0,
+ "operation": "subtract",
+ "notes": "manual correction"
+}
+```
+
+```bash
+curl -X POST "${BASE}/api/v1/admin/users/123/balance" \
+ -H "x-api-key: ${KEY}" \
+ -H "Idempotency-Key: balance-subtract-cm1234567890" \
+ -H "Content-Type: application/json" \
+ -d '{
+ "balance":100.00,
+ "operation":"subtract",
+ "notes":"manual correction"
+ }'
+```
+
+### 4) 购买页 URL Query 透传(iframe / 新窗口一致)
+当 Sub2API 打开 `purchase_subscription_url` 时,会统一追加:
+- `user_id`
+- `token`
+- `theme`(`light` / `dark`)
+- `ui_mode`(固定 `embedded`)
+
+示例:
+```text
+https://pay.example.com/pay?user_id=123&token=&theme=light&ui_mode=embedded
+```
+
+### 5) 失败处理建议
+- 支付成功与充值成功分状态落库
+- 回调验签成功后立即标记“支付成功”
+- 支付成功但充值失败的订单允许后续重试
+- 重试保持相同 `code`,并使用新的 `Idempotency-Key`
+
+### 6) `doc_url` 配置建议
+- 查看链接:`https://github.com/Wei-Shaw/sub2api/blob/main/ADMIN_PAYMENT_INTEGRATION_API.md`
+- 下载链接:`https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/ADMIN_PAYMENT_INTEGRATION_API.md`
+
+---
+
+## English
+
+### Purpose
+This document describes the minimal Sub2API Admin API surface for external payment integrations (for example, `sub2apipay`), including:
+- Recharge after payment success
+- User lookup
+- Manual balance correction
+- Purchase page query parameter forwarding
+
+### Base URL
+- Production: `https://`
+- Beta: `http://:8084`
+
+### Authentication
+Recommended headers:
+- `x-api-key: admin-<64hex>`
+- `Content-Type: application/json`
+- `Idempotency-Key` for idempotent endpoints
+
+Note: Admin JWT can also access admin routes, but Admin API Key is recommended for server-to-server integration.
+
+### 1) Create and Redeem in one step
+`POST /api/v1/admin/redeem-codes/create-and-redeem`
+
+Use case: atomically create a redeem code and redeem it to a target user.
+
+Headers:
+- `x-api-key`
+- `Idempotency-Key`
+
+Request body:
+```json
+{
+ "code": "s2p_cm1234567890",
+ "type": "balance",
+ "value": 100.0,
+ "user_id": 123,
+ "notes": "sub2apipay order: cm1234567890"
+}
+```
+
+Idempotency behavior:
+- Same `code` and same `used_by`: `200`
+- Same `code` but different `used_by`: `409`
+- Missing `Idempotency-Key`: `400` (`IDEMPOTENCY_KEY_REQUIRED`)
+
+curl example:
+```bash
+curl -X POST "${BASE}/api/v1/admin/redeem-codes/create-and-redeem" \
+ -H "x-api-key: ${KEY}" \
+ -H "Idempotency-Key: pay-cm1234567890-success" \
+ -H "Content-Type: application/json" \
+ -d '{
+ "code":"s2p_cm1234567890",
+ "type":"balance",
+ "value":100.00,
+ "user_id":123,
+ "notes":"sub2apipay order: cm1234567890"
+ }'
+```
+
+### 2) Query User (optional pre-check)
+`GET /api/v1/admin/users/:id`
+
+```bash
+curl -s "${BASE}/api/v1/admin/users/123" \
+ -H "x-api-key: ${KEY}"
+```
+
+### 3) Balance Adjustment (existing API)
+`POST /api/v1/admin/users/:id/balance`
+
+Use case: manual correction with `set` / `add` / `subtract`.
+
+Request body example (`subtract`):
+```json
+{
+ "balance": 100.0,
+ "operation": "subtract",
+ "notes": "manual correction"
+}
+```
+
+```bash
+curl -X POST "${BASE}/api/v1/admin/users/123/balance" \
+ -H "x-api-key: ${KEY}" \
+ -H "Idempotency-Key: balance-subtract-cm1234567890" \
+ -H "Content-Type: application/json" \
+ -d '{
+ "balance":100.00,
+ "operation":"subtract",
+ "notes":"manual correction"
+ }'
+```
+
+### 4) Purchase URL query forwarding (iframe and new tab)
+When Sub2API opens `purchase_subscription_url`, it appends:
+- `user_id`
+- `token`
+- `theme` (`light` / `dark`)
+- `ui_mode` (fixed: `embedded`)
+
+Example:
+```text
+https://pay.example.com/pay?user_id=123&token=&theme=light&ui_mode=embedded
+```
+
+### 5) Failure handling recommendations
+- Persist payment success and recharge success as separate states
+- Mark payment as successful immediately after verified callback
+- Allow retry for orders with payment success but recharge failure
+- Keep the same `code` for retry, and use a new `Idempotency-Key`
+
+### 6) Recommended `doc_url`
+- View URL: `https://github.com/Wei-Shaw/sub2api/blob/main/ADMIN_PAYMENT_INTEGRATION_API.md`
+- Download URL: `https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/ADMIN_PAYMENT_INTEGRATION_API.md`
diff --git a/frontend/src/api/__tests__/sora.spec.ts b/frontend/src/api/__tests__/sora.spec.ts
new file mode 100644
index 00000000..88c0c416
--- /dev/null
+++ b/frontend/src/api/__tests__/sora.spec.ts
@@ -0,0 +1,80 @@
+import { describe, expect, it } from 'vitest'
+import {
+ normalizeGenerationListResponse,
+ normalizeModelFamiliesResponse
+} from '../sora'
+
+describe('sora api normalizers', () => {
+ it('normalizes generation list from data shape', () => {
+ const result = normalizeGenerationListResponse({
+ data: [{ id: 1, status: 'pending' }],
+ total: 9,
+ page: 2
+ })
+
+ expect(result.data).toHaveLength(1)
+ expect(result.total).toBe(9)
+ expect(result.page).toBe(2)
+ })
+
+ it('normalizes generation list from items shape', () => {
+ const result = normalizeGenerationListResponse({
+ items: [{ id: 1, status: 'completed' }],
+ total: 1
+ })
+
+ expect(result.data).toHaveLength(1)
+ expect(result.total).toBe(1)
+ expect(result.page).toBe(1)
+ })
+
+ it('falls back to empty generation list on invalid payload', () => {
+ const result = normalizeGenerationListResponse(null)
+ expect(result).toEqual({ data: [], total: 0, page: 1 })
+ })
+
+ it('normalizes family model payload', () => {
+ const result = normalizeModelFamiliesResponse({
+ data: [
+ {
+ id: 'sora2',
+ name: 'Sora 2',
+ type: 'video',
+ orientations: ['landscape', 'portrait'],
+ durations: [10, 15]
+ }
+ ]
+ })
+
+ expect(result).toHaveLength(1)
+ expect(result[0].id).toBe('sora2')
+ expect(result[0].orientations).toEqual(['landscape', 'portrait'])
+ expect(result[0].durations).toEqual([10, 15])
+ })
+
+ it('normalizes legacy flat model list into families', () => {
+ const result = normalizeModelFamiliesResponse({
+ items: [
+ { id: 'sora2-landscape-10s', type: 'video' },
+ { id: 'sora2-portrait-15s', type: 'video' },
+ { id: 'gpt-image-square', type: 'image' }
+ ]
+ })
+
+ const sora2 = result.find((m) => m.id === 'sora2')
+ expect(sora2).toBeTruthy()
+ expect(sora2?.orientations).toEqual(['landscape', 'portrait'])
+ expect(sora2?.durations).toEqual([10, 15])
+
+ const image = result.find((m) => m.id === 'gpt-image')
+ expect(image).toBeTruthy()
+ expect(image?.type).toBe('image')
+ expect(image?.orientations).toEqual(['square'])
+ })
+
+ it('falls back to empty families on invalid payload', () => {
+ expect(normalizeModelFamiliesResponse(undefined)).toEqual([])
+ expect(normalizeModelFamiliesResponse({})).toEqual([])
+ })
+})
+
diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts
index 1b8ae9ad..56571699 100644
--- a/frontend/src/api/admin/accounts.ts
+++ b/frontend/src/api/admin/accounts.ts
@@ -369,6 +369,22 @@ export async function getTodayStats(id: number): Promise {
return data
}
+export interface BatchTodayStatsResponse {
+ stats: Record
+}
+
+/**
+ * 批量获取多个账号的今日统计
+ * @param accountIds - 账号 ID 列表
+ * @returns 以账号 ID(字符串)为键的统计映射
+ */
+export async function getBatchTodayStats(accountIds: number[]): Promise {
+ const { data } = await apiClient.post('/admin/accounts/today-stats/batch', {
+ account_ids: accountIds
+ })
+ return data
+}
+
/**
* Set account schedulable status
* @param id - Account ID
@@ -556,6 +572,7 @@ export const accountsAPI = {
clearError,
getUsage,
getTodayStats,
+ getBatchTodayStats,
clearRateLimit,
getTempUnschedulableStatus,
resetTempUnschedulable,
diff --git a/frontend/src/api/admin/apiKeys.ts b/frontend/src/api/admin/apiKeys.ts
new file mode 100644
index 00000000..79f6e174
--- /dev/null
+++ b/frontend/src/api/admin/apiKeys.ts
@@ -0,0 +1,33 @@
+/**
+ * Admin API Keys API endpoints
+ * Handles API key management for administrators
+ */
+
+import { apiClient } from '../client'
+import type { ApiKey } from '@/types'
+
+export interface UpdateApiKeyGroupResult {
+ api_key: ApiKey
+ auto_granted_group_access: boolean
+ granted_group_id?: number
+ granted_group_name?: string
+}
+
+/**
+ * Update an API key's group binding
+ * @param id - API Key ID
+ * @param groupId - Group ID (0 to unbind, positive to bind, null/undefined to skip)
+ * @returns Updated API key with auto-grant info
+ */
+export async function updateApiKeyGroup(id: number, groupId: number | null): Promise {
+ const { data } = await apiClient.put(`/admin/api-keys/${id}`, {
+ group_id: groupId === null ? 0 : groupId
+ })
+ return data
+}
+
+export const apiKeysAPI = {
+ updateApiKeyGroup
+}
+
+export default apiKeysAPI
diff --git a/frontend/src/api/admin/dashboard.ts b/frontend/src/api/admin/dashboard.ts
index ae48bec2..54bd92a4 100644
--- a/frontend/src/api/admin/dashboard.ts
+++ b/frontend/src/api/admin/dashboard.ts
@@ -8,8 +8,10 @@ import type {
DashboardStats,
TrendDataPoint,
ModelStat,
+ GroupStat,
ApiKeyUsageTrendPoint,
- UserUsageTrendPoint
+ UserUsageTrendPoint,
+ UsageRequestType
} from '@/types'
/**
@@ -49,6 +51,7 @@ export interface TrendParams {
model?: string
account_id?: number
group_id?: number
+ request_type?: UsageRequestType
stream?: boolean
billing_type?: number | null
}
@@ -78,6 +81,7 @@ export interface ModelStatsParams {
model?: string
account_id?: number
group_id?: number
+ request_type?: UsageRequestType
stream?: boolean
billing_type?: number | null
}
@@ -98,6 +102,34 @@ export async function getModelStats(params?: ModelStatsParams): Promise {
+ const { data } = await apiClient.get('/admin/dashboard/groups', { params })
+ return data
+}
+
export interface ApiKeyTrendParams extends TrendParams {
limit?: number
}
@@ -200,6 +232,7 @@ export const dashboardAPI = {
getRealtimeMetrics,
getUsageTrend,
getModelStats,
+ getGroupStats,
getApiKeyUsageTrend,
getUserUsageTrend,
getBatchUsersUsage,
diff --git a/frontend/src/api/admin/dataManagement.ts b/frontend/src/api/admin/dataManagement.ts
new file mode 100644
index 00000000..cec71446
--- /dev/null
+++ b/frontend/src/api/admin/dataManagement.ts
@@ -0,0 +1,332 @@
+import { apiClient } from '../client'
+
+export type BackupType = 'postgres' | 'redis' | 'full'
+export type BackupJobStatus = 'queued' | 'running' | 'succeeded' | 'failed' | 'partial_succeeded'
+
+export interface BackupAgentInfo {
+ status: string
+ version: string
+ uptime_seconds: number
+}
+
+export interface BackupAgentHealth {
+ enabled: boolean
+ reason: string
+ socket_path: string
+ agent?: BackupAgentInfo
+}
+
+export interface DataManagementPostgresConfig {
+ host: string
+ port: number
+ user: string
+ password?: string
+ password_configured?: boolean
+ database: string
+ ssl_mode: string
+ container_name: string
+}
+
+export interface DataManagementRedisConfig {
+ addr: string
+ username: string
+ password?: string
+ password_configured?: boolean
+ db: number
+ container_name: string
+}
+
+export interface DataManagementS3Config {
+ enabled: boolean
+ endpoint: string
+ region: string
+ bucket: string
+ access_key_id: string
+ secret_access_key?: string
+ secret_access_key_configured?: boolean
+ prefix: string
+ force_path_style: boolean
+ use_ssl: boolean
+}
+
+export interface DataManagementConfig {
+ source_mode: 'direct' | 'docker_exec'
+ backup_root: string
+ sqlite_path?: string
+ retention_days: number
+ keep_last: number
+ active_postgres_profile_id?: string
+ active_redis_profile_id?: string
+ active_s3_profile_id?: string
+ postgres: DataManagementPostgresConfig
+ redis: DataManagementRedisConfig
+ s3: DataManagementS3Config
+}
+
+export type SourceType = 'postgres' | 'redis'
+
+export interface DataManagementSourceConfig {
+ host: string
+ port: number
+ user: string
+ password?: string
+ database: string
+ ssl_mode: string
+ addr: string
+ username: string
+ db: number
+ container_name: string
+}
+
+export interface DataManagementSourceProfile {
+ source_type: SourceType
+ profile_id: string
+ name: string
+ is_active: boolean
+ password_configured?: boolean
+ config: DataManagementSourceConfig
+ created_at?: string
+ updated_at?: string
+}
+
+export interface TestS3Request {
+ endpoint: string
+ region: string
+ bucket: string
+ access_key_id: string
+ secret_access_key: string
+ prefix?: string
+ force_path_style?: boolean
+ use_ssl?: boolean
+}
+
+export interface TestS3Response {
+ ok: boolean
+ message: string
+}
+
+export interface CreateBackupJobRequest {
+ backup_type: BackupType
+ upload_to_s3?: boolean
+ s3_profile_id?: string
+ postgres_profile_id?: string
+ redis_profile_id?: string
+ idempotency_key?: string
+}
+
+export interface CreateBackupJobResponse {
+ job_id: string
+ status: BackupJobStatus
+}
+
+export interface BackupArtifactInfo {
+ local_path: string
+ size_bytes: number
+ sha256: string
+}
+
+export interface BackupS3Info {
+ bucket: string
+ key: string
+ etag: string
+}
+
+export interface BackupJob {
+ job_id: string
+ backup_type: BackupType
+ status: BackupJobStatus
+ triggered_by: string
+ s3_profile_id?: string
+ postgres_profile_id?: string
+ redis_profile_id?: string
+ started_at?: string
+ finished_at?: string
+ error_message?: string
+ artifact?: BackupArtifactInfo
+ s3?: BackupS3Info
+}
+
+export interface ListSourceProfilesResponse {
+ items: DataManagementSourceProfile[]
+}
+
+export interface CreateSourceProfileRequest {
+ profile_id: string
+ name: string
+ config: DataManagementSourceConfig
+ set_active?: boolean
+}
+
+export interface UpdateSourceProfileRequest {
+ name: string
+ config: DataManagementSourceConfig
+}
+
+export interface DataManagementS3Profile {
+ profile_id: string
+ name: string
+ is_active: boolean
+ s3: DataManagementS3Config
+ secret_access_key_configured?: boolean
+ created_at?: string
+ updated_at?: string
+}
+
+export interface ListS3ProfilesResponse {
+ items: DataManagementS3Profile[]
+}
+
+export interface CreateS3ProfileRequest {
+ profile_id: string
+ name: string
+ enabled: boolean
+ endpoint: string
+ region: string
+ bucket: string
+ access_key_id: string
+ secret_access_key?: string
+ prefix?: string
+ force_path_style?: boolean
+ use_ssl?: boolean
+ set_active?: boolean
+}
+
+export interface UpdateS3ProfileRequest {
+ name: string
+ enabled: boolean
+ endpoint: string
+ region: string
+ bucket: string
+ access_key_id: string
+ secret_access_key?: string
+ prefix?: string
+ force_path_style?: boolean
+ use_ssl?: boolean
+}
+
+export interface ListBackupJobsRequest {
+ page_size?: number
+ page_token?: string
+ status?: BackupJobStatus
+ backup_type?: BackupType
+}
+
+export interface ListBackupJobsResponse {
+ items: BackupJob[]
+ next_page_token?: string
+}
+
+export async function getAgentHealth(): Promise {
+ const { data } = await apiClient.get('/admin/data-management/agent/health')
+ return data
+}
+
+export async function getConfig(): Promise {
+ const { data } = await apiClient.get('/admin/data-management/config')
+ return data
+}
+
+export async function updateConfig(request: DataManagementConfig): Promise {
+ const { data } = await apiClient.put('/admin/data-management/config', request)
+ return data
+}
+
+export async function testS3(request: TestS3Request): Promise {
+ const { data } = await apiClient.post('/admin/data-management/s3/test', request)
+ return data
+}
+
+export async function listSourceProfiles(sourceType: SourceType): Promise {
+ const { data } = await apiClient.get(`/admin/data-management/sources/${sourceType}/profiles`)
+ return data
+}
+
+export async function createSourceProfile(sourceType: SourceType, request: CreateSourceProfileRequest): Promise {
+ const { data } = await apiClient.post(`/admin/data-management/sources/${sourceType}/profiles`, request)
+ return data
+}
+
+export async function updateSourceProfile(sourceType: SourceType, profileID: string, request: UpdateSourceProfileRequest): Promise {
+ const { data } = await apiClient.put(`/admin/data-management/sources/${sourceType}/profiles/${profileID}`, request)
+ return data
+}
+
+export async function deleteSourceProfile(sourceType: SourceType, profileID: string): Promise {
+ await apiClient.delete(`/admin/data-management/sources/${sourceType}/profiles/${profileID}`)
+}
+
+export async function setActiveSourceProfile(sourceType: SourceType, profileID: string): Promise {
+ const { data } = await apiClient.post(`/admin/data-management/sources/${sourceType}/profiles/${profileID}/activate`)
+ return data
+}
+
+export async function listS3Profiles(): Promise {
+ const { data } = await apiClient.get('/admin/data-management/s3/profiles')
+ return data
+}
+
+export async function createS3Profile(request: CreateS3ProfileRequest): Promise {
+ const { data } = await apiClient.post('/admin/data-management/s3/profiles', request)
+ return data
+}
+
+export async function updateS3Profile(profileID: string, request: UpdateS3ProfileRequest): Promise {
+ const { data } = await apiClient.put(`/admin/data-management/s3/profiles/${profileID}`, request)
+ return data
+}
+
+export async function deleteS3Profile(profileID: string): Promise {
+ await apiClient.delete(`/admin/data-management/s3/profiles/${profileID}`)
+}
+
+export async function setActiveS3Profile(profileID: string): Promise {
+ const { data } = await apiClient.post(`/admin/data-management/s3/profiles/${profileID}/activate`)
+ return data
+}
+
+export async function createBackupJob(request: CreateBackupJobRequest): Promise {
+ const headers = request.idempotency_key
+ ? { 'X-Idempotency-Key': request.idempotency_key }
+ : undefined
+
+ const { data } = await apiClient.post(
+ '/admin/data-management/backups',
+ request,
+ { headers }
+ )
+ return data
+}
+
+export async function listBackupJobs(request?: ListBackupJobsRequest): Promise {
+ const { data } = await apiClient.get('/admin/data-management/backups', {
+ params: request
+ })
+ return data
+}
+
+export async function getBackupJob(jobID: string): Promise {
+ const { data } = await apiClient.get(`/admin/data-management/backups/${jobID}`)
+ return data
+}
+
+export const dataManagementAPI = {
+ getAgentHealth,
+ getConfig,
+ updateConfig,
+ listSourceProfiles,
+ createSourceProfile,
+ updateSourceProfile,
+ deleteSourceProfile,
+ setActiveSourceProfile,
+ testS3,
+ listS3Profiles,
+ createS3Profile,
+ updateS3Profile,
+ deleteS3Profile,
+ setActiveS3Profile,
+ createBackupJob,
+ listBackupJobs,
+ getBackupJob
+}
+
+export default dataManagementAPI
diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts
index ffb9b179..5db998e5 100644
--- a/frontend/src/api/admin/index.ts
+++ b/frontend/src/api/admin/index.ts
@@ -20,6 +20,8 @@ import antigravityAPI from './antigravity'
import userAttributesAPI from './userAttributes'
import opsAPI from './ops'
import errorPassthroughAPI from './errorPassthrough'
+import dataManagementAPI from './dataManagement'
+import apiKeysAPI from './apiKeys'
/**
* Unified admin API object for convenient access
@@ -41,7 +43,9 @@ export const adminAPI = {
antigravity: antigravityAPI,
userAttributes: userAttributesAPI,
ops: opsAPI,
- errorPassthrough: errorPassthroughAPI
+ errorPassthrough: errorPassthroughAPI,
+ dataManagement: dataManagementAPI,
+ apiKeys: apiKeysAPI
}
export {
@@ -61,7 +65,9 @@ export {
antigravityAPI,
userAttributesAPI,
opsAPI,
- errorPassthroughAPI
+ errorPassthroughAPI,
+ dataManagementAPI,
+ apiKeysAPI
}
export default adminAPI
@@ -69,3 +75,4 @@ export default adminAPI
// Re-export types used by components
export type { BalanceHistoryItem } from './users'
export type { ErrorPassthroughRule, CreateRuleRequest, UpdateRuleRequest } from './errorPassthrough'
+export type { BackupAgentHealth, DataManagementConfig } from './dataManagement'
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index 3dc76fe7..c1b767ba 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -5,6 +5,11 @@
import { apiClient } from '../client'
+export interface DefaultSubscriptionSetting {
+ group_id: number
+ validity_days: number
+}
+
/**
* System settings interface
*/
@@ -20,6 +25,7 @@ export interface SystemSettings {
// Default settings
default_balance: number
default_concurrency: number
+ default_subscriptions: DefaultSubscriptionSetting[]
// OEM settings
site_name: string
site_logo: string
@@ -31,6 +37,7 @@ export interface SystemSettings {
hide_ccs_import_button: boolean
purchase_subscription_enabled: boolean
purchase_subscription_url: string
+ sora_client_enabled: boolean
// SMTP settings
smtp_host: string
smtp_port: number
@@ -66,6 +73,9 @@ export interface SystemSettings {
ops_realtime_monitoring_enabled: boolean
ops_query_mode_default: 'auto' | 'raw' | 'preagg' | string
ops_metrics_interval_seconds: number
+
+ // Claude Code version check
+ min_claude_code_version: string
}
export interface UpdateSettingsRequest {
@@ -77,6 +87,7 @@ export interface UpdateSettingsRequest {
totp_enabled?: boolean // TOTP 双因素认证
default_balance?: number
default_concurrency?: number
+ default_subscriptions?: DefaultSubscriptionSetting[]
site_name?: string
site_logo?: string
site_subtitle?: string
@@ -87,6 +98,7 @@ export interface UpdateSettingsRequest {
hide_ccs_import_button?: boolean
purchase_subscription_enabled?: boolean
purchase_subscription_url?: string
+ sora_client_enabled?: boolean
smtp_host?: string
smtp_port?: number
smtp_username?: string
@@ -112,6 +124,7 @@ export interface UpdateSettingsRequest {
ops_realtime_monitoring_enabled?: boolean
ops_query_mode_default?: 'auto' | 'raw' | 'preagg' | string
ops_metrics_interval_seconds?: number
+ min_claude_code_version?: string
}
/**
@@ -251,6 +264,142 @@ export async function updateStreamTimeoutSettings(
return data
}
+// ==================== Sora S3 Settings ====================
+
+export interface SoraS3Settings {
+ enabled: boolean
+ endpoint: string
+ region: string
+ bucket: string
+ access_key_id: string
+ secret_access_key_configured: boolean
+ prefix: string
+ force_path_style: boolean
+ cdn_url: string
+ default_storage_quota_bytes: number
+}
+
+export interface SoraS3Profile {
+ profile_id: string
+ name: string
+ is_active: boolean
+ enabled: boolean
+ endpoint: string
+ region: string
+ bucket: string
+ access_key_id: string
+ secret_access_key_configured: boolean
+ prefix: string
+ force_path_style: boolean
+ cdn_url: string
+ default_storage_quota_bytes: number
+ updated_at: string
+}
+
+export interface ListSoraS3ProfilesResponse {
+ active_profile_id: string
+ items: SoraS3Profile[]
+}
+
+export interface UpdateSoraS3SettingsRequest {
+ profile_id?: string
+ enabled: boolean
+ endpoint: string
+ region: string
+ bucket: string
+ access_key_id: string
+ secret_access_key?: string
+ prefix: string
+ force_path_style: boolean
+ cdn_url: string
+ default_storage_quota_bytes: number
+}
+
+export interface CreateSoraS3ProfileRequest {
+ profile_id: string
+ name: string
+ set_active?: boolean
+ enabled: boolean
+ endpoint: string
+ region: string
+ bucket: string
+ access_key_id: string
+ secret_access_key?: string
+ prefix: string
+ force_path_style: boolean
+ cdn_url: string
+ default_storage_quota_bytes: number
+}
+
+export interface UpdateSoraS3ProfileRequest {
+ name: string
+ enabled: boolean
+ endpoint: string
+ region: string
+ bucket: string
+ access_key_id: string
+ secret_access_key?: string
+ prefix: string
+ force_path_style: boolean
+ cdn_url: string
+ default_storage_quota_bytes: number
+}
+
+export interface TestSoraS3ConnectionRequest {
+ profile_id?: string
+ enabled: boolean
+ endpoint: string
+ region: string
+ bucket: string
+ access_key_id: string
+ secret_access_key?: string
+ prefix: string
+ force_path_style: boolean
+ cdn_url: string
+ default_storage_quota_bytes?: number
+}
+
+export async function getSoraS3Settings(): Promise {
+ const { data } = await apiClient.get('/admin/settings/sora-s3')
+ return data
+}
+
+export async function updateSoraS3Settings(settings: UpdateSoraS3SettingsRequest): Promise {
+ const { data } = await apiClient.put('/admin/settings/sora-s3', settings)
+ return data
+}
+
+export async function testSoraS3Connection(
+ settings: TestSoraS3ConnectionRequest
+): Promise<{ message: string }> {
+ const { data } = await apiClient.post<{ message: string }>('/admin/settings/sora-s3/test', settings)
+ return data
+}
+
+export async function listSoraS3Profiles(): Promise {
+ const { data } = await apiClient.get('/admin/settings/sora-s3/profiles')
+ return data
+}
+
+export async function createSoraS3Profile(request: CreateSoraS3ProfileRequest): Promise {
+ const { data } = await apiClient.post('/admin/settings/sora-s3/profiles', request)
+ return data
+}
+
+export async function updateSoraS3Profile(profileID: string, request: UpdateSoraS3ProfileRequest): Promise {
+ const { data } = await apiClient.put(`/admin/settings/sora-s3/profiles/${profileID}`, request)
+ return data
+}
+
+export async function deleteSoraS3Profile(profileID: string): Promise {
+ await apiClient.delete(`/admin/settings/sora-s3/profiles/${profileID}`)
+}
+
+export async function setActiveSoraS3Profile(profileID: string): Promise {
+ const { data } = await apiClient.post(`/admin/settings/sora-s3/profiles/${profileID}/activate`)
+ return data
+}
+
export const settingsAPI = {
getSettings,
updateSettings,
@@ -260,7 +409,15 @@ export const settingsAPI = {
regenerateAdminApiKey,
deleteAdminApiKey,
getStreamTimeoutSettings,
- updateStreamTimeoutSettings
+ updateStreamTimeoutSettings,
+ getSoraS3Settings,
+ updateSoraS3Settings,
+ testSoraS3Connection,
+ listSoraS3Profiles,
+ createSoraS3Profile,
+ updateSoraS3Profile,
+ deleteSoraS3Profile,
+ setActiveSoraS3Profile
}
export default settingsAPI
diff --git a/frontend/src/api/admin/usage.ts b/frontend/src/api/admin/usage.ts
index 94f7b57b..66c84410 100644
--- a/frontend/src/api/admin/usage.ts
+++ b/frontend/src/api/admin/usage.ts
@@ -4,7 +4,7 @@
*/
import { apiClient } from '../client'
-import type { AdminUsageLog, UsageQueryParams, PaginatedResponse } from '@/types'
+import type { AdminUsageLog, UsageQueryParams, PaginatedResponse, UsageRequestType } from '@/types'
// ==================== Types ====================
@@ -39,6 +39,7 @@ export interface UsageCleanupFilters {
account_id?: number
group_id?: number
model?: string | null
+ request_type?: UsageRequestType | null
stream?: boolean | null
billing_type?: number | null
}
@@ -66,6 +67,7 @@ export interface CreateUsageCleanupTaskRequest {
account_id?: number
group_id?: number
model?: string | null
+ request_type?: UsageRequestType | null
stream?: boolean | null
billing_type?: number | null
timezone?: string
@@ -104,6 +106,7 @@ export async function getStats(params: {
account_id?: number
group_id?: number
model?: string
+ request_type?: UsageRequestType
stream?: boolean
period?: string
start_date?: string
diff --git a/frontend/src/api/admin/users.ts b/frontend/src/api/admin/users.ts
index 287aef96..d36a2a5a 100644
--- a/frontend/src/api/admin/users.ts
+++ b/frontend/src/api/admin/users.ts
@@ -4,7 +4,7 @@
*/
import { apiClient } from '../client'
-import type { AdminUser, UpdateUserRequest, PaginatedResponse } from '@/types'
+import type { AdminUser, UpdateUserRequest, PaginatedResponse, ApiKey } from '@/types'
/**
* List all users with pagination
@@ -145,8 +145,8 @@ export async function toggleStatus(id: number, status: 'active' | 'disabled'): P
* @param id - User ID
* @returns List of user's API keys
*/
-export async function getUserApiKeys(id: number): Promise> {
- const { data } = await apiClient.get>(`/admin/users/${id}/api-keys`)
+export async function getUserApiKeys(id: number): Promise> {
+ const { data } = await apiClient.get>(`/admin/users/${id}/api-keys`)
return data
}
diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts
index 22db5a44..95f9ff31 100644
--- a/frontend/src/api/client.ts
+++ b/frontend/src/api/client.ts
@@ -267,6 +267,7 @@ apiClient.interceptors.response.use(
return Promise.reject({
status,
code: apiData.code,
+ error: apiData.error,
message: apiData.message || apiData.detail || error.message
})
}
diff --git a/frontend/src/api/sora.ts b/frontend/src/api/sora.ts
new file mode 100644
index 00000000..45108454
--- /dev/null
+++ b/frontend/src/api/sora.ts
@@ -0,0 +1,307 @@
+/**
+ * Sora 客户端 API
+ * 封装所有 Sora 生成、作品库、配额等接口调用
+ */
+
+import { apiClient } from './client'
+
+// ==================== 类型定义 ====================
+
+export interface SoraGeneration {
+ id: number
+ user_id: number
+ model: string
+ prompt: string
+ media_type: string
+ status: string // pending | generating | completed | failed | cancelled
+ storage_type: string // upstream | s3 | local
+ media_url: string
+ media_urls: string[]
+ s3_object_keys: string[]
+ file_size_bytes: number
+ error_message: string
+ created_at: string
+ completed_at?: string
+}
+
+export interface GenerateRequest {
+ model: string
+ prompt: string
+ video_count?: number
+ media_type?: string
+ image_input?: string
+ api_key_id?: number
+}
+
+export interface GenerateResponse {
+ generation_id: number
+ status: string
+}
+
+export interface GenerationListResponse {
+ data: SoraGeneration[]
+ total: number
+ page: number
+}
+
+export interface QuotaInfo {
+ quota_bytes: number
+ used_bytes: number
+ available_bytes: number
+ quota_source: string // user | group | system | unlimited
+ source?: string // 兼容旧字段
+}
+
+export interface StorageStatus {
+ s3_enabled: boolean
+ s3_healthy: boolean
+ local_enabled: boolean
+}
+
+/** 单个扁平模型(旧接口,保留兼容) */
+export interface SoraModel {
+ id: string
+ name: string
+ type: string // video | image
+ orientation?: string
+ duration?: number
+}
+
+/** 模型家族(新接口 — 后端从 soraModelConfigs 自动聚合) */
+export interface SoraModelFamily {
+ id: string // 家族 ID,如 "sora2"
+ name: string // 显示名,如 "Sora 2"
+ type: string // "video" | "image"
+ orientations: string[] // ["landscape", "portrait"] 或 ["landscape", "portrait", "square"]
+ durations?: number[] // [10, 15, 25](仅视频模型)
+}
+
+type LooseRecord = Record
+
+function asRecord(value: unknown): LooseRecord | null {
+ return value !== null && typeof value === 'object' ? value as LooseRecord : null
+}
+
+function asArray(value: unknown): T[] {
+ return Array.isArray(value) ? value as T[] : []
+}
+
+function asPositiveInt(value: unknown): number | null {
+ const n = Number(value)
+ if (!Number.isFinite(n) || n <= 0) return null
+ return Math.round(n)
+}
+
+function dedupeStrings(values: string[]): string[] {
+ return Array.from(new Set(values))
+}
+
+function extractOrientationFromModelID(modelID: string): string | null {
+ const m = modelID.match(/-(landscape|portrait|square)(?:-\d+s)?$/i)
+ return m ? m[1].toLowerCase() : null
+}
+
+function extractDurationFromModelID(modelID: string): number | null {
+ const m = modelID.match(/-(\d+)s$/i)
+ return m ? asPositiveInt(m[1]) : null
+}
+
+function normalizeLegacyFamilies(candidates: unknown[]): SoraModelFamily[] {
+ const familyMap = new Map()
+
+ for (const item of candidates) {
+ const model = asRecord(item)
+ if (!model || typeof model.id !== 'string' || model.id.trim() === '') continue
+
+ const rawID = model.id.trim()
+ const type = model.type === 'image' ? 'image' : 'video'
+ const name = typeof model.name === 'string' && model.name.trim() ? model.name.trim() : rawID
+ const baseID = rawID.replace(/-(landscape|portrait|square)(?:-\d+s)?$/i, '')
+ const orientation =
+ typeof model.orientation === 'string' && model.orientation
+ ? model.orientation.toLowerCase()
+ : extractOrientationFromModelID(rawID)
+ const duration = asPositiveInt(model.duration) ?? extractDurationFromModelID(rawID)
+ const familyKey = baseID || rawID
+
+ const family = familyMap.get(familyKey) ?? {
+ id: familyKey,
+ name,
+ type,
+ orientations: [],
+ durations: []
+ }
+
+ if (orientation) {
+ family.orientations.push(orientation)
+ }
+ if (type === 'video' && duration) {
+ family.durations = family.durations || []
+ family.durations.push(duration)
+ }
+
+ familyMap.set(familyKey, family)
+ }
+
+ return Array.from(familyMap.values())
+ .map((family) => ({
+ ...family,
+ orientations:
+ family.orientations.length > 0
+ ? dedupeStrings(family.orientations)
+ : (family.type === 'image' ? ['square'] : ['landscape']),
+ durations:
+ family.type === 'video'
+ ? Array.from(new Set((family.durations || []).filter((d): d is number => Number.isFinite(d)))).sort((a, b) => a - b)
+ : []
+ }))
+ .filter((family) => family.id !== '')
+}
+
+function normalizeModelFamilyRecord(item: unknown): SoraModelFamily | null {
+ const model = asRecord(item)
+ if (!model || typeof model.id !== 'string' || model.id.trim() === '') return null
+ // 仅把明确的“家族结构”识别为 family;老结构(单模型)走 legacy 聚合逻辑。
+ if (!Array.isArray(model.orientations) && !Array.isArray(model.durations)) return null
+
+ const orientations = asArray(model.orientations).filter((o): o is string => typeof o === 'string' && o.length > 0)
+ const durations = asArray(model.durations)
+ .map(asPositiveInt)
+ .filter((d): d is number => d !== null)
+
+ return {
+ id: model.id.trim(),
+ name: typeof model.name === 'string' && model.name.trim() ? model.name.trim() : model.id.trim(),
+ type: model.type === 'image' ? 'image' : 'video',
+ orientations: dedupeStrings(orientations),
+ durations: Array.from(new Set(durations)).sort((a, b) => a - b)
+ }
+}
+
+function extractCandidateArray(payload: unknown): unknown[] {
+ if (Array.isArray(payload)) return payload
+ const record = asRecord(payload)
+ if (!record) return []
+
+ const keys: Array = ['data', 'items', 'models', 'families']
+ for (const key of keys) {
+ if (Array.isArray(record[key])) {
+ return record[key] as unknown[]
+ }
+ }
+ return []
+}
+
+export function normalizeModelFamiliesResponse(payload: unknown): SoraModelFamily[] {
+ const candidates = extractCandidateArray(payload)
+ if (candidates.length === 0) return []
+
+ const normalized = candidates
+ .map(normalizeModelFamilyRecord)
+ .filter((item): item is SoraModelFamily => item !== null)
+
+ if (normalized.length > 0) return normalized
+ return normalizeLegacyFamilies(candidates)
+}
+
+export function normalizeGenerationListResponse(payload: unknown): GenerationListResponse {
+ const record = asRecord(payload)
+ if (!record) {
+ return { data: [], total: 0, page: 1 }
+ }
+
+ const data = Array.isArray(record.data)
+ ? (record.data as SoraGeneration[])
+ : Array.isArray(record.items)
+ ? (record.items as SoraGeneration[])
+ : []
+
+ const total = Number(record.total)
+ const page = Number(record.page)
+
+ return {
+ data,
+ total: Number.isFinite(total) ? total : data.length,
+ page: Number.isFinite(page) && page > 0 ? page : 1
+ }
+}
+
+// ==================== API 方法 ====================
+
+/** 异步生成 — 创建 pending 记录后立即返回 */
+export async function generate(req: GenerateRequest): Promise {
+ const { data } = await apiClient.post('/sora/generate', req)
+ return data
+}
+
+/** 查询生成记录列表 */
+export async function listGenerations(params?: {
+ page?: number
+ page_size?: number
+ status?: string
+ storage_type?: string
+ media_type?: string
+}): Promise {
+ const { data } = await apiClient.get('/sora/generations', { params })
+ return normalizeGenerationListResponse(data)
+}
+
+/** 查询生成记录详情 */
+export async function getGeneration(id: number): Promise {
+ const { data } = await apiClient.get(`/sora/generations/${id}`)
+ return data
+}
+
+/** 删除生成记录 */
+export async function deleteGeneration(id: number): Promise<{ message: string }> {
+ const { data } = await apiClient.delete<{ message: string }>(`/sora/generations/${id}`)
+ return data
+}
+
+/** 取消生成任务 */
+export async function cancelGeneration(id: number): Promise<{ message: string }> {
+ const { data } = await apiClient.post<{ message: string }>(`/sora/generations/${id}/cancel`)
+ return data
+}
+
+/** 手动保存到 S3 */
+export async function saveToStorage(
+ id: number
+): Promise<{ message: string; object_key: string; object_keys?: string[] }> {
+ const { data } = await apiClient.post<{ message: string; object_key: string; object_keys?: string[] }>(
+ `/sora/generations/${id}/save`
+ )
+ return data
+}
+
+/** 查询配额信息 */
+export async function getQuota(): Promise {
+ const { data } = await apiClient.get('/sora/quota')
+ return data
+}
+
+/** 获取可用模型家族列表 */
+export async function getModels(): Promise {
+ const { data } = await apiClient.get('/sora/models')
+ return normalizeModelFamiliesResponse(data)
+}
+
+/** 获取存储状态 */
+export async function getStorageStatus(): Promise {
+ const { data } = await apiClient.get('/sora/storage-status')
+ return data
+}
+
+const soraAPI = {
+ generate,
+ listGenerations,
+ getGeneration,
+ deleteGeneration,
+ cancelGeneration,
+ saveToStorage,
+ getQuota,
+ getModels,
+ getStorageStatus
+}
+
+export default soraAPI
diff --git a/frontend/src/components/account/AccountCapacityCell.vue b/frontend/src/components/account/AccountCapacityCell.vue
index ae338aca..2a4babf2 100644
--- a/frontend/src/components/account/AccountCapacityCell.vue
+++ b/frontend/src/components/account/AccountCapacityCell.vue
@@ -52,6 +52,25 @@
{{ account.max_sessions }}
+
+
+
+
+
+ {{ currentRPM }}
+ /
+ {{ account.base_rpm }}
+ {{ rpmStrategyTag }}
+
+
@@ -125,19 +144,15 @@ const windowCostClass = computed(() => {
const limit = props.account.window_cost_limit || 0
const reserve = props.account.window_cost_sticky_reserve || 10
- // >= 阈值+预留: 完全不可调度 (红色)
if (current >= limit + reserve) {
return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400'
}
- // >= 阈值: 仅粘性会话 (橙色)
if (current >= limit) {
return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'
}
- // >= 80% 阈值: 警告 (黄色)
if (current >= limit * 0.8) {
return 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/30 dark:text-yellow-400'
}
- // 正常 (绿色)
return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400'
})
@@ -165,15 +180,12 @@ const sessionLimitClass = computed(() => {
const current = activeSessions.value
const max = props.account.max_sessions || 0
- // >= 最大: 完全占满 (红色)
if (current >= max) {
return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400'
}
- // >= 80%: 警告 (黄色)
if (current >= max * 0.8) {
return 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/30 dark:text-yellow-400'
}
- // 正常 (绿色)
return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400'
})
@@ -191,6 +203,89 @@ const sessionLimitTooltip = computed(() => {
return t('admin.accounts.capacity.sessions.normal', { idle })
})
+// 是否显示 RPM 限制
+const showRpmLimit = computed(() => {
+ return (
+ isAnthropicOAuthOrSetupToken.value &&
+ props.account.base_rpm !== undefined &&
+ props.account.base_rpm !== null &&
+ props.account.base_rpm > 0
+ )
+})
+
+// 当前 RPM 计数
+const currentRPM = computed(() => props.account.current_rpm ?? 0)
+
+// RPM 策略
+const rpmStrategy = computed(() => props.account.rpm_strategy || 'tiered')
+
+// RPM 策略标签
+const rpmStrategyTag = computed(() => {
+ return rpmStrategy.value === 'sticky_exempt' ? '[S]' : '[T]'
+})
+
+// RPM buffer 计算(与后端一致:base <= 0 时 buffer 为 0)
+const rpmBuffer = computed(() => {
+ const base = props.account.base_rpm || 0
+ return props.account.rpm_sticky_buffer ?? (base > 0 ? Math.max(1, Math.floor(base / 5)) : 0)
+})
+
+// RPM 状态样式
+const rpmClass = computed(() => {
+ if (!showRpmLimit.value) return ''
+
+ const current = currentRPM.value
+ const base = props.account.base_rpm ?? 0
+ const buffer = rpmBuffer.value
+
+ if (rpmStrategy.value === 'tiered') {
+ if (current >= base + buffer) {
+ return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400'
+ }
+ if (current >= base) {
+ return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'
+ }
+ } else {
+ if (current >= base) {
+ return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'
+ }
+ }
+ if (current >= base * 0.8) {
+ return 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/30 dark:text-yellow-400'
+ }
+ return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400'
+})
+
+// RPM 提示文字(增强版:显示策略、区域、缓冲区)
+const rpmTooltip = computed(() => {
+ if (!showRpmLimit.value) return ''
+
+ const current = currentRPM.value
+ const base = props.account.base_rpm ?? 0
+ const buffer = rpmBuffer.value
+
+ if (rpmStrategy.value === 'tiered') {
+ if (current >= base + buffer) {
+ return t('admin.accounts.capacity.rpm.tieredBlocked', { buffer })
+ }
+ if (current >= base) {
+ return t('admin.accounts.capacity.rpm.tieredStickyOnly', { buffer })
+ }
+ if (current >= base * 0.8) {
+ return t('admin.accounts.capacity.rpm.tieredWarning')
+ }
+ return t('admin.accounts.capacity.rpm.tieredNormal')
+ } else {
+ if (current >= base) {
+ return t('admin.accounts.capacity.rpm.stickyExemptOver')
+ }
+ if (current >= base * 0.8) {
+ return t('admin.accounts.capacity.rpm.stickyExemptWarning')
+ }
+ return t('admin.accounts.capacity.rpm.stickyExemptNormal')
+ }
+})
+
// 格式化费用显示
const formatCost = (value: number | null | undefined) => {
if (value === null || value === undefined) return '0'
diff --git a/frontend/src/components/account/AccountStatusIndicator.vue b/frontend/src/components/account/AccountStatusIndicator.vue
index 8816eb26..e8331c25 100644
--- a/frontend/src/components/account/AccountStatusIndicator.vue
+++ b/frontend/src/components/account/AccountStatusIndicator.vue
@@ -166,7 +166,8 @@ const activeModelRateLimits = computed(() => {
const formatScopeName = (scope: string): string => {
const aliases: Record = {
// Claude 系列
- 'claude-opus-4-6-thinking': 'COpus46',
+ 'claude-opus-4-6': 'COpus46',
+ 'claude-opus-4-6-thinking': 'COpus46T',
'claude-sonnet-4-6': 'CSon46',
'claude-sonnet-4-5': 'CSon45',
'claude-sonnet-4-5-thinking': 'CSon45T',
@@ -180,6 +181,7 @@ const formatScopeName = (scope: string): string => {
'gemini-3.1-pro-high': 'G3PH',
'gemini-3.1-pro-low': 'G3PL',
'gemini-3-pro-image': 'G3PI',
+ 'gemini-3.1-flash-image': 'GImage',
// 其他
'gpt-oss-120b-medium': 'GPT120',
'tab_flash_lite_preview': 'TabFL',
diff --git a/frontend/src/components/account/AccountTodayStatsCell.vue b/frontend/src/components/account/AccountTodayStatsCell.vue
index a920f314..a422d1f0 100644
--- a/frontend/src/components/account/AccountTodayStatsCell.vue
+++ b/frontend/src/components/account/AccountTodayStatsCell.vue
@@ -1,26 +1,26 @@
-
+
-
- {{ error }}
+
+ {{ props.error }}
-
+
{{ t('admin.accounts.stats.requests') }}:
{{
- formatNumber(stats.requests)
+ formatNumber(props.stats.requests)
}}
@@ -29,21 +29,21 @@
>{{ t('admin.accounts.stats.tokens') }}:
{{
- formatTokens(stats.tokens)
+ formatTokens(props.stats.tokens)
}}
{{ t('usage.accountBilled') }}:
{{
- formatCurrency(stats.cost)
+ formatCurrency(props.stats.cost)
}}
-
+
{{ t('usage.userBilled') }}:
{{
- formatCurrency(stats.user_cost)
+ formatCurrency(props.stats.user_cost)
}}
@@ -54,22 +54,25 @@
diff --git a/frontend/src/components/account/AccountUsageCell.vue b/frontend/src/components/account/AccountUsageCell.vue
index b47b4115..859bd7c9 100644
--- a/frontend/src/components/account/AccountUsageCell.vue
+++ b/frontend/src/components/account/AccountUsageCell.vue
@@ -397,14 +397,16 @@ const antigravity3ProUsageFromAPI = computed(() =>
// Gemini 3 Flash from API
const antigravity3FlashUsageFromAPI = computed(() => getAntigravityUsageFromAPI(['gemini-3-flash']))
-// Gemini 3 Image from API
-const antigravity3ImageUsageFromAPI = computed(() => getAntigravityUsageFromAPI(['gemini-3-pro-image']))
+// Gemini Image from API
+const antigravity3ImageUsageFromAPI = computed(() =>
+ getAntigravityUsageFromAPI(['gemini-3.1-flash-image', 'gemini-3-pro-image'])
+)
// Claude from API (all Claude model variants)
const antigravityClaudeUsageFromAPI = computed(() =>
getAntigravityUsageFromAPI([
'claude-sonnet-4-5', 'claude-opus-4-5-thinking',
- 'claude-sonnet-4-6', 'claude-opus-4-6-thinking',
+ 'claude-sonnet-4-6', 'claude-opus-4-6', 'claude-opus-4-6-thinking',
])
)
diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue
index 1c4395ec..30c3d739 100644
--- a/frontend/src/components/account/BulkEditAccountModal.vue
+++ b/frontend/src/components/account/BulkEditAccountModal.vue
@@ -21,6 +21,16 @@
+
+
+
+
+ {{ t('admin.accounts.bulkEdit.mixedPlatformWarning', { platforms: selectedPlatforms.join(', ') }) }}
+
+
+
@@ -157,7 +167,7 @@
+
+
+
+
+
+
+
+
+
+ {{ t('admin.accounts.quotaControl.rpmLimit.hint') }}
+
+
+
+
+
+
+
+
{{ t('admin.accounts.quotaControl.rpmLimit.baseRpmHint') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
{{ t('admin.accounts.quotaControl.rpmLimit.stickyBufferHint') }}
+
+
+
+
+
@@ -641,6 +756,17 @@
+
+
diff --git a/frontend/src/components/admin/user/UserEditModal.vue b/frontend/src/components/admin/user/UserEditModal.vue
index 70ebd2d3..e537dbf6 100644
--- a/frontend/src/components/admin/user/UserEditModal.vue
+++ b/frontend/src/components/admin/user/UserEditModal.vue
@@ -37,6 +37,14 @@
+
+
+
+
+ GB
+
+
{{ t('admin.users.soraStorageQuotaHint') }}
+
@@ -66,11 +74,11 @@ const emit = defineEmits(['close', 'success'])
const { t } = useI18n(); const appStore = useAppStore(); const { copyToClipboard } = useClipboard()
const submitting = ref(false); const passwordCopied = ref(false)
-const form = reactive({ email: '', password: '', username: '', notes: '', concurrency: 1, customAttributes: {} as UserAttributeValuesMap })
+const form = reactive({ email: '', password: '', username: '', notes: '', concurrency: 1, sora_storage_quota_gb: 0, customAttributes: {} as UserAttributeValuesMap })
watch(() => props.user, (u) => {
if (u) {
- Object.assign(form, { email: u.email, password: '', username: u.username || '', notes: u.notes || '', concurrency: u.concurrency, customAttributes: {} })
+ Object.assign(form, { email: u.email, password: '', username: u.username || '', notes: u.notes || '', concurrency: u.concurrency, sora_storage_quota_gb: Number(((u.sora_storage_quota_bytes || 0) / (1024 * 1024 * 1024)).toFixed(2)), customAttributes: {} })
passwordCopied.value = false
}
}, { immediate: true })
@@ -97,7 +105,7 @@ const handleUpdateUser = async () => {
}
submitting.value = true
try {
- const data: any = { email: form.email, username: form.username, notes: form.notes, concurrency: form.concurrency }
+ const data: any = { email: form.email, username: form.username, notes: form.notes, concurrency: form.concurrency, sora_storage_quota_bytes: Math.round((form.sora_storage_quota_gb || 0) * 1024 * 1024 * 1024) }
if (form.password.trim()) data.password = form.password.trim()
await adminAPI.users.update(props.user.id, data)
if (Object.keys(form.customAttributes).length > 0) await adminAPI.userAttributes.updateUserAttributeValues(props.user.id, form.customAttributes)
diff --git a/frontend/src/components/charts/GroupDistributionChart.vue b/frontend/src/components/charts/GroupDistributionChart.vue
new file mode 100644
index 00000000..d9231a63
--- /dev/null
+++ b/frontend/src/components/charts/GroupDistributionChart.vue
@@ -0,0 +1,152 @@
+
+
+
+ {{ t('admin.dashboard.groupDistribution') }}
+
+
+
+
+
+
+
+
+
+
+
+
+ | {{ t('admin.dashboard.group') }} |
+ {{ t('admin.dashboard.requests') }} |
+ {{ t('admin.dashboard.tokens') }} |
+ {{ t('admin.dashboard.actual') }} |
+ {{ t('admin.dashboard.standard') }} |
+
+
+
+
+ |
+ {{ group.group_name || t('admin.dashboard.noGroup') }}
+ |
+
+ {{ formatNumber(group.requests) }}
+ |
+
+ {{ formatTokens(group.total_tokens) }}
+ |
+
+ ${{ formatCost(group.actual_cost) }}
+ |
+
+ ${{ formatCost(group.cost) }}
+ |
+
+
+
+
+
+
+ {{ t('admin.dashboard.noDataAvailable') }}
+
+
+
+
+
diff --git a/frontend/src/components/keys/UseKeyModal.vue b/frontend/src/components/keys/UseKeyModal.vue
index 7b0cbc68..4dd7ff0c 100644
--- a/frontend/src/components/keys/UseKeyModal.vue
+++ b/frontend/src/components/keys/UseKeyModal.vue
@@ -268,6 +268,7 @@ const clientTabs = computed((): TabConfig[] => {
case 'openai':
return [
{ id: 'codex', label: t('keys.useKeyModal.cliTabs.codexCli'), icon: TerminalIcon },
+ { id: 'codex-ws', label: t('keys.useKeyModal.cliTabs.codexCliWs'), icon: TerminalIcon },
{ id: 'opencode', label: t('keys.useKeyModal.cliTabs.opencode'), icon: TerminalIcon }
]
case 'gemini':
@@ -306,7 +307,7 @@ const showShellTabs = computed(() => activeClientTab.value !== 'opencode')
const currentTabs = computed(() => {
if (!showShellTabs.value) return []
- if (props.platform === 'openai') {
+ if (activeClientTab.value === 'codex' || activeClientTab.value === 'codex-ws') {
return openaiTabs
}
return shellTabs
@@ -401,6 +402,9 @@ const currentFiles = computed((): FileConfig[] => {
switch (props.platform) {
case 'openai':
+ if (activeClientTab.value === 'codex-ws') {
+ return generateOpenAIWsFiles(baseUrl, apiKey)
+ }
return generateOpenAIFiles(baseUrl, apiKey)
case 'gemini':
return [generateGeminiCliContent(baseUrl, apiKey)]
@@ -524,6 +528,47 @@ requires_openai_auth = true`
]
}
+function generateOpenAIWsFiles(baseUrl: string, apiKey: string): FileConfig[] {
+ const isWindows = activeTab.value === 'windows'
+ const configDir = isWindows ? '%userprofile%\\.codex' : '~/.codex'
+
+ // config.toml content with WebSocket v2
+ const configContent = `model_provider = "sub2api"
+model = "gpt-5.3-codex"
+model_reasoning_effort = "high"
+network_access = "enabled"
+disable_response_storage = true
+windows_wsl_setup_acknowledged = true
+model_verbosity = "high"
+
+[model_providers.sub2api]
+name = "sub2api"
+base_url = "${baseUrl}"
+wire_api = "responses"
+supports_websockets = true
+requires_openai_auth = true
+
+[features]
+responses_websockets_v2 = true`
+
+ // auth.json content
+ const authContent = `{
+ "OPENAI_API_KEY": "${apiKey}"
+}`
+
+ return [
+ {
+ path: `${configDir}/config.toml`,
+ content: configContent,
+ hint: t('keys.useKeyModal.openai.configTomlHint')
+ },
+ {
+ path: `${configDir}/auth.json`,
+ content: authContent
+ }
+ ]
+}
+
function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: string, pathLabel?: string): FileConfig {
const provider: Record = {
[platform]: {
@@ -675,11 +720,90 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
}
}
const geminiModels = {
- 'gemini-2.0-flash': { name: 'Gemini 2.0 Flash' },
- 'gemini-2.5-flash': { name: 'Gemini 2.5 Flash' },
- 'gemini-2.5-pro': { name: 'Gemini 2.5 Pro' },
- 'gemini-3-flash-preview': { name: 'Gemini 3 Flash Preview' },
- 'gemini-3-pro-preview': { name: 'Gemini 3 Pro Preview' }
+ 'gemini-2.0-flash': {
+ name: 'Gemini 2.0 Flash',
+ limit: {
+ context: 1048576,
+ output: 65536
+ },
+ modalities: {
+ input: ['text', 'image', 'pdf'],
+ output: ['text']
+ }
+ },
+ 'gemini-2.5-flash': {
+ name: 'Gemini 2.5 Flash',
+ limit: {
+ context: 1048576,
+ output: 65536
+ },
+ modalities: {
+ input: ['text', 'image', 'pdf'],
+ output: ['text']
+ }
+ },
+ 'gemini-2.5-pro': {
+ name: 'Gemini 2.5 Pro',
+ limit: {
+ context: 2097152,
+ output: 65536
+ },
+ modalities: {
+ input: ['text', 'image', 'pdf'],
+ output: ['text']
+ },
+ options: {
+ thinking: {
+ budgetTokens: 24576,
+ type: 'enabled'
+ }
+ }
+ },
+ 'gemini-3-flash-preview': {
+ name: 'Gemini 3 Flash Preview',
+ limit: {
+ context: 1048576,
+ output: 65536
+ },
+ modalities: {
+ input: ['text', 'image', 'pdf'],
+ output: ['text']
+ }
+ },
+ 'gemini-3-pro-preview': {
+ name: 'Gemini 3 Pro Preview',
+ limit: {
+ context: 1048576,
+ output: 65536
+ },
+ modalities: {
+ input: ['text', 'image', 'pdf'],
+ output: ['text']
+ },
+ options: {
+ thinking: {
+ budgetTokens: 24576,
+ type: 'enabled'
+ }
+ }
+ },
+ 'gemini-3.1-pro-preview': {
+ name: 'Gemini 3.1 Pro Preview',
+ limit: {
+ context: 1048576,
+ output: 65536
+ },
+ modalities: {
+ input: ['text', 'image', 'pdf'],
+ output: ['text']
+ },
+ options: {
+ thinking: {
+ budgetTokens: 24576,
+ type: 'enabled'
+ }
+ }
+ }
}
const antigravityGeminiModels = {
@@ -785,8 +909,8 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
}
}
},
- 'gemini-3-pro-image': {
- name: 'Gemini 3 Pro (Image)',
+ 'gemini-3.1-flash-image': {
+ name: 'Gemini 3.1 Flash Image',
limit: {
context: 1048576,
output: 65536
@@ -804,25 +928,38 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
}
}
const claudeModels = {
- 'claude-opus-4-5-thinking': {
- name: 'Claude Opus 4.5 Thinking',
+ 'claude-opus-4-6-thinking': {
+ name: 'Claude 4.6 Opus (Thinking)',
limit: {
context: 200000,
- output: 64000
+ output: 128000
+ },
+ modalities: {
+ input: ['text', 'image', 'pdf'],
+ output: ['text']
+ },
+ options: {
+ thinking: {
+ budgetTokens: 24576,
+ type: 'enabled'
+ }
}
},
- 'claude-sonnet-4-5-thinking': {
- name: 'Claude Sonnet 4.5 Thinking',
- limit: {
- context: 200000,
- output: 64000
- }
- },
- 'claude-sonnet-4-5': {
- name: 'Claude Sonnet 4.5',
+ 'claude-sonnet-4-6': {
+ name: 'Claude 4.6 Sonnet',
limit: {
context: 200000,
output: 64000
+ },
+ modalities: {
+ input: ['text', 'image', 'pdf'],
+ output: ['text']
+ },
+ options: {
+ thinking: {
+ budgetTokens: 24576,
+ type: 'enabled'
+ }
}
}
}
diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue
index e5afde9c..b356e3e5 100644
--- a/frontend/src/components/layout/AppSidebar.vue
+++ b/frontend/src/components/layout/AppSidebar.vue
@@ -290,6 +290,26 @@ const CreditCardIcon = {
)
}
+const RechargeSubscriptionIcon = {
+ render: () =>
+ h(
+ 'svg',
+ { fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' },
+ [
+ h('path', {
+ 'stroke-linecap': 'round',
+ 'stroke-linejoin': 'round',
+ d: 'M2.25 7.5A2.25 2.25 0 014.5 5.25h15A2.25 2.25 0 0121.75 7.5v9A2.25 2.25 0 0119.5 18.75h-15A2.25 2.25 0 012.25 16.5v-9z'
+ }),
+ h('path', {
+ 'stroke-linecap': 'round',
+ 'stroke-linejoin': 'round',
+ d: 'M6.75 12h3m4.5 0h3m-3-3v6'
+ })
+ ]
+ )
+}
+
const GlobeIcon = {
render: () =>
h(
@@ -320,6 +340,36 @@ const ServerIcon = {
)
}
+const DatabaseIcon = {
+ render: () =>
+ h(
+ 'svg',
+ { fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' },
+ [
+ h('path', {
+ 'stroke-linecap': 'round',
+ 'stroke-linejoin': 'round',
+ d: 'M3.75 5.25C3.75 4.007 7.443 3 12 3s8.25 1.007 8.25 2.25S16.557 7.5 12 7.5 3.75 6.493 3.75 5.25z'
+ }),
+ h('path', {
+ 'stroke-linecap': 'round',
+ 'stroke-linejoin': 'round',
+ d: 'M3.75 5.25v4.5C3.75 10.993 7.443 12 12 12s8.25-1.007 8.25-2.25v-4.5'
+ }),
+ h('path', {
+ 'stroke-linecap': 'round',
+ 'stroke-linejoin': 'round',
+ d: 'M3.75 9.75v4.5c0 1.243 3.693 2.25 8.25 2.25s8.25-1.007 8.25-2.25v-4.5'
+ }),
+ h('path', {
+ 'stroke-linecap': 'round',
+ 'stroke-linejoin': 'round',
+ d: 'M3.75 14.25v4.5C3.75 19.993 7.443 21 12 21s8.25-1.007 8.25-2.25v-4.5'
+ })
+ ]
+ )
+}
+
const BellIcon = {
render: () =>
h(
@@ -415,6 +465,21 @@ const ChevronDoubleLeftIcon = {
)
}
+const SoraIcon = {
+ render: () =>
+ h(
+ 'svg',
+ { fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' },
+ [
+ h('path', {
+ 'stroke-linecap': 'round',
+ 'stroke-linejoin': 'round',
+ d: 'M9.813 15.904L9 18.75l-.813-2.846a4.5 4.5 0 00-3.09-3.09L2.25 12l2.846-.813a4.5 4.5 0 003.09-3.09L9 5.25l.813 2.846a4.5 4.5 0 003.09 3.09L15.75 12l-2.846.813a4.5 4.5 0 00-3.09 3.09z'
+ })
+ ]
+ )
+}
+
const ChevronDoubleRightIcon = {
render: () =>
h(
@@ -437,12 +502,15 @@ const userNavItems = computed(() => {
{ path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon },
{ path: '/usage', label: t('nav.usage'), icon: ChartIcon, hideInSimpleMode: true },
{ path: '/subscriptions', label: t('nav.mySubscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
+ ...(appStore.cachedPublicSettings?.sora_client_enabled
+ ? [{ path: '/sora', label: t('nav.sora'), icon: SoraIcon }]
+ : []),
...(appStore.cachedPublicSettings?.purchase_subscription_enabled
? [
{
path: '/purchase',
label: t('nav.buySubscription'),
- icon: CreditCardIcon,
+ icon: RechargeSubscriptionIcon,
hideInSimpleMode: true
}
]
@@ -459,12 +527,15 @@ const personalNavItems = computed(() => {
{ path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon },
{ path: '/usage', label: t('nav.usage'), icon: ChartIcon, hideInSimpleMode: true },
{ path: '/subscriptions', label: t('nav.mySubscriptions'), icon: CreditCardIcon, hideInSimpleMode: true },
+ ...(appStore.cachedPublicSettings?.sora_client_enabled
+ ? [{ path: '/sora', label: t('nav.sora'), icon: SoraIcon }]
+ : []),
...(appStore.cachedPublicSettings?.purchase_subscription_enabled
? [
{
path: '/purchase',
label: t('nav.buySubscription'),
- icon: CreditCardIcon,
+ icon: RechargeSubscriptionIcon,
hideInSimpleMode: true
}
]
@@ -490,17 +561,19 @@ const adminNavItems = computed(() => {
{ path: '/admin/proxies', label: t('nav.proxies'), icon: ServerIcon },
{ path: '/admin/redeem', label: t('nav.redeemCodes'), icon: TicketIcon, hideInSimpleMode: true },
{ path: '/admin/promo-codes', label: t('nav.promoCodes'), icon: GiftIcon, hideInSimpleMode: true },
- { path: '/admin/usage', label: t('nav.usage'), icon: ChartIcon },
+ { path: '/admin/usage', label: t('nav.usage'), icon: ChartIcon }
]
// 简单模式下,在系统设置前插入 API密钥
if (authStore.isSimpleMode) {
const filtered = baseItems.filter(item => !item.hideInSimpleMode)
filtered.push({ path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon })
+ filtered.push({ path: '/admin/data-management', label: t('nav.dataManagement'), icon: DatabaseIcon })
filtered.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon })
return filtered
}
+ baseItems.push({ path: '/admin/data-management', label: t('nav.dataManagement'), icon: DatabaseIcon })
baseItems.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon })
return baseItems
})
diff --git a/frontend/src/components/sora/SoraDownloadDialog.vue b/frontend/src/components/sora/SoraDownloadDialog.vue
new file mode 100644
index 00000000..5f39980f
--- /dev/null
+++ b/frontend/src/components/sora/SoraDownloadDialog.vue
@@ -0,0 +1,217 @@
+
+
+
+
+
+
+
📥
+
{{ t('sora.downloadTitle') }}
+
{{ t('sora.downloadExpirationWarning') }}
+
+
+
+
+
+ {{ isExpired ? t('sora.upstreamExpired') : t('sora.upstreamCountdown', { time: remainingText }) }}
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/sora/SoraGeneratePage.vue b/frontend/src/components/sora/SoraGeneratePage.vue
new file mode 100644
index 00000000..1f77edc4
--- /dev/null
+++ b/frontend/src/components/sora/SoraGeneratePage.vue
@@ -0,0 +1,430 @@
+
+
+
+
+
+
{{ t('sora.welcomeTitle') }}
+
{{ t('sora.welcomeSubtitle') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ ⚠️
+ {{ t('sora.noStorageToastMessage') }}
+
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/sora/SoraLibraryPage.vue b/frontend/src/components/sora/SoraLibraryPage.vue
new file mode 100644
index 00000000..0e2b5e1d
--- /dev/null
+++ b/frontend/src/components/sora/SoraLibraryPage.vue
@@ -0,0 +1,576 @@
+
+
+
+
+
+
+
+
+ {{ t('sora.galleryCount', { count: filteredItems.length }) }}
+
+
+
+
+
+
+
+
+
+
![]()
+
+ {{ item.media_type === 'video' ? '🎬' : '🎨' }}
+
+
+
+
+ {{ item.media_type === 'video' ? 'VIDEO' : 'IMAGE' }}
+
+
+
+
+
+
+
+
+
+
▶
+
+
+
+ {{ formatDuration(item) }}
+
+
+
+
+
+
{{ item.model }}
+
{{ formatTime(item.created_at) }}
+
+
+
+
+
+
+
🎬
+
{{ t('sora.galleryEmptyTitle') }}
+
{{ t('sora.galleryEmptyDesc') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/sora/SoraMediaPreview.vue b/frontend/src/components/sora/SoraMediaPreview.vue
new file mode 100644
index 00000000..09a3aea1
--- /dev/null
+++ b/frontend/src/components/sora/SoraMediaPreview.vue
@@ -0,0 +1,282 @@
+
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/sora/SoraNoStorageWarning.vue b/frontend/src/components/sora/SoraNoStorageWarning.vue
new file mode 100644
index 00000000..c5ede271
--- /dev/null
+++ b/frontend/src/components/sora/SoraNoStorageWarning.vue
@@ -0,0 +1,39 @@
+
+
+
⚠️
+
+
{{ t('sora.noStorageWarningTitle') }}
+
{{ t('sora.noStorageWarningDesc') }}
+
+
+
+
+
+
+
diff --git a/frontend/src/components/sora/SoraProgressCard.vue b/frontend/src/components/sora/SoraProgressCard.vue
new file mode 100644
index 00000000..69b28ef9
--- /dev/null
+++ b/frontend/src/components/sora/SoraProgressCard.vue
@@ -0,0 +1,609 @@
+
+
+
+
+
+
+
+ {{ generation.prompt }}
+
+
+
+
+ ⛔ {{ t('sora.errorCategory') }}
+
+
+ {{ generation.error_message }}
+
+
+
+
+
+
+ {{ progressInfoText }}
+ {{ progressInfoRight }}
+
+
+
+
+
+
+
![]()
+
+
+
+
+
+
+
+
+
+
+
+ ✓ {{ t('sora.savedToCloud') }}
+
+
+
+
+
+ 📥 {{ t('sora.downloadLocal') }}
+
+
+
+ ⏱ {{ t('sora.upstreamCountdown', { time: countdownText }) }} {{ t('sora.canDownload') }}
+
+
+ {{ t('sora.upstreamExpired') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/components/sora/SoraPromptBar.vue b/frontend/src/components/sora/SoraPromptBar.vue
new file mode 100644
index 00000000..f5f1bfc9
--- /dev/null
+++ b/frontend/src/components/sora/SoraPromptBar.vue
@@ -0,0 +1,738 @@
+
+
+
+
+
+
+
+
+ ▼
+
+
+
+
+ ▼
+
+
+
+ ⚠ {{ t('sora.noCredentialHint') }}
+
+
+
+ ⚠ {{ t('sora.noStorageConfigured') }}
+
+
+
+
+
+
+
![]()
+
+
+
{{ t('sora.referenceImage') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
{{ imageError }}
+
+
+
+
+
+
diff --git a/frontend/src/components/sora/SoraQuotaBar.vue b/frontend/src/components/sora/SoraQuotaBar.vue
new file mode 100644
index 00000000..4a3af027
--- /dev/null
+++ b/frontend/src/components/sora/SoraQuotaBar.vue
@@ -0,0 +1,87 @@
+
+
+
+
+ {{ formatBytes(quota.used_bytes) }} / {{ quota.quota_bytes === 0 ? '∞' : formatBytes(quota.quota_bytes) }}
+
+
+
+
+
+
+
diff --git a/frontend/src/composables/__tests__/useModelWhitelist.spec.ts b/frontend/src/composables/__tests__/useModelWhitelist.spec.ts
new file mode 100644
index 00000000..4088e5a4
--- /dev/null
+++ b/frontend/src/composables/__tests__/useModelWhitelist.spec.ts
@@ -0,0 +1,18 @@
+import { describe, expect, it } from 'vitest'
+import { buildModelMappingObject, getModelsByPlatform } from '../useModelWhitelist'
+
+describe('useModelWhitelist', () => {
+ it('antigravity 模型列表包含图片模型兼容项', () => {
+ const models = getModelsByPlatform('antigravity')
+
+ expect(models).toContain('gemini-3.1-flash-image')
+ expect(models).toContain('gemini-3-pro-image')
+ })
+
+ it('whitelist 模式会忽略通配符条目', () => {
+ const mapping = buildModelMappingObject('whitelist', ['claude-*', 'gemini-3.1-flash-image'], [])
+ expect(mapping).toEqual({
+ 'gemini-3.1-flash-image': 'gemini-3.1-flash-image'
+ })
+ })
+})
diff --git a/frontend/src/composables/__tests__/useOpenAIOAuth.spec.ts b/frontend/src/composables/__tests__/useOpenAIOAuth.spec.ts
new file mode 100644
index 00000000..ee3f7990
--- /dev/null
+++ b/frontend/src/composables/__tests__/useOpenAIOAuth.spec.ts
@@ -0,0 +1,49 @@
+import { describe, expect, it, vi } from 'vitest'
+
+vi.mock('@/stores/app', () => ({
+ useAppStore: () => ({
+ showError: vi.fn()
+ })
+}))
+
+vi.mock('@/api/admin', () => ({
+ adminAPI: {
+ accounts: {
+ generateAuthUrl: vi.fn(),
+ exchangeCode: vi.fn(),
+ refreshOpenAIToken: vi.fn(),
+ validateSoraSessionToken: vi.fn()
+ }
+ }
+}))
+
+import { useOpenAIOAuth } from '@/composables/useOpenAIOAuth'
+
+describe('useOpenAIOAuth.buildCredentials', () => {
+ it('should keep client_id when token response contains it', () => {
+ const oauth = useOpenAIOAuth({ platform: 'sora' })
+ const creds = oauth.buildCredentials({
+ access_token: 'at',
+ refresh_token: 'rt',
+ client_id: 'app_sora_client',
+ expires_at: 1700000000
+ })
+
+ expect(creds.client_id).toBe('app_sora_client')
+ expect(creds.access_token).toBe('at')
+ expect(creds.refresh_token).toBe('rt')
+ })
+
+ it('should keep legacy behavior when client_id is missing', () => {
+ const oauth = useOpenAIOAuth({ platform: 'openai' })
+ const creds = oauth.buildCredentials({
+ access_token: 'at',
+ refresh_token: 'rt',
+ expires_at: 1700000000
+ })
+
+ expect(Object.prototype.hasOwnProperty.call(creds, 'client_id')).toBe(false)
+ expect(creds.access_token).toBe('at')
+ expect(creds.refresh_token).toBe('rt')
+ })
+})
diff --git a/frontend/src/composables/useModelWhitelist.ts b/frontend/src/composables/useModelWhitelist.ts
index ddc5661b..444e4b91 100644
--- a/frontend/src/composables/useModelWhitelist.ts
+++ b/frontend/src/composables/useModelWhitelist.ts
@@ -24,6 +24,8 @@ const openaiModels = [
// GPT-5.2 系列
'gpt-5.2', 'gpt-5.2-2025-12-11', 'gpt-5.2-chat-latest',
'gpt-5.2-codex', 'gpt-5.2-pro', 'gpt-5.2-pro-2025-12-11',
+ // GPT-5.3 系列
+ 'gpt-5.3-codex', 'gpt-5.3-codex-spark',
'chatgpt-4o-latest',
'gpt-4o-audio-preview', 'gpt-4o-realtime-preview'
]
@@ -75,6 +77,7 @@ const soraModels = [
const antigravityModels = [
// Claude 4.5+ 系列
'claude-opus-4-6',
+ 'claude-opus-4-6-thinking',
'claude-opus-4-5-thinking',
'claude-sonnet-4-6',
'claude-sonnet-4-5',
@@ -88,10 +91,11 @@ const antigravityModels = [
'gemini-3-flash',
'gemini-3-pro-high',
'gemini-3-pro-low',
- 'gemini-3-pro-image',
// Gemini 3.1 系列
'gemini-3.1-pro-high',
'gemini-3.1-pro-low',
+ 'gemini-3.1-flash-image',
+ 'gemini-3-pro-image',
// 其他
'gpt-oss-120b-medium',
'tab_flash_lite_preview'
@@ -309,6 +313,7 @@ const antigravityPresetMappings = [
// 精确映射
{ label: 'Sonnet 4.6', from: 'claude-sonnet-4-6', to: 'claude-sonnet-4-6', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
{ label: 'Sonnet 4.5', from: 'claude-sonnet-4-5', to: 'claude-sonnet-4-5', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' },
+ { label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' },
{ label: 'Opus 4.6-thinking', from: 'claude-opus-4-6-thinking', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' }
]
diff --git a/frontend/src/composables/useOpenAIOAuth.ts b/frontend/src/composables/useOpenAIOAuth.ts
index 32045cbe..0f777a38 100644
--- a/frontend/src/composables/useOpenAIOAuth.ts
+++ b/frontend/src/composables/useOpenAIOAuth.ts
@@ -5,6 +5,7 @@ import { adminAPI } from '@/api/admin'
export interface OpenAITokenInfo {
access_token?: string
refresh_token?: string
+ client_id?: string
id_token?: string
token_type?: string
expires_in?: number
@@ -192,6 +193,10 @@ export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) {
scope: tokenInfo.scope
}
+ if (tokenInfo.client_id) {
+ creds.client_id = tokenInfo.client_id
+ }
+
// Include OpenAI specific IDs (required for forwarding)
if (tokenInfo.chatgpt_account_id) {
creds.chatgpt_account_id = tokenInfo.chatgpt_account_id
diff --git a/frontend/src/i18n/index.ts b/frontend/src/i18n/index.ts
index 00e34dc2..5dab65e8 100644
--- a/frontend/src/i18n/index.ts
+++ b/frontend/src/i18n/index.ts
@@ -68,6 +68,14 @@ export async function setLocale(locale: string): Promise {
i18n.global.locale.value = locale
localStorage.setItem(LOCALE_KEY, locale)
document.documentElement.setAttribute('lang', locale)
+
+ // 同步更新浏览器页签标题,使其跟随语言切换
+ const { resolveDocumentTitle } = await import('@/router/title')
+ const { default: router } = await import('@/router')
+ const { useAppStore } = await import('@/stores/app')
+ const route = router.currentRoute.value
+ const appStore = useAppStore()
+ document.title = resolveDocumentTitle(route.meta.title, appStore.siteName, route.meta.titleKey as string)
}
export function getLocale(): LocaleCode {
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index b1a03789..0726f116 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -270,6 +270,7 @@ export default {
redeemCodes: 'Redeem Codes',
ops: 'Ops',
promoCodes: 'Promo Codes',
+ dataManagement: 'Data Management',
settings: 'Settings',
myAccount: 'My Account',
lightMode: 'Light Mode',
@@ -279,8 +280,9 @@ export default {
logout: 'Logout',
github: 'GitHub',
mySubscriptions: 'My Subscriptions',
- buySubscription: 'Purchase Subscription',
- docs: 'Docs'
+ buySubscription: 'Recharge / Subscription',
+ docs: 'Docs',
+ sora: 'Sora Studio'
},
// Auth
@@ -406,9 +408,12 @@ export default {
day: 'Day',
hour: 'Hour',
modelDistribution: 'Model Distribution',
+ groupDistribution: 'Group Usage Distribution',
tokenUsageTrend: 'Token Usage Trend',
noDataAvailable: 'No data available',
model: 'Model',
+ group: 'Group',
+ noGroup: 'No Group',
requests: 'Requests',
tokens: 'Tokens',
actual: 'Actual',
@@ -499,6 +504,7 @@ export default {
claudeCode: 'Claude Code',
geminiCli: 'Gemini CLI',
codexCli: 'Codex CLI',
+ codexCliWs: 'Codex CLI (WebSocket)',
opencode: 'OpenCode',
},
antigravity: {
@@ -612,8 +618,10 @@ export default {
firstToken: 'First Token',
duration: 'Duration',
time: 'Time',
+ ws: 'WS',
stream: 'Stream',
sync: 'Sync',
+ unknown: 'Unknown',
in: 'In',
out: 'Out',
cacheRead: 'Read',
@@ -827,9 +835,12 @@ export default {
day: 'Day',
hour: 'Hour',
modelDistribution: 'Model Distribution',
+ groupDistribution: 'Group Usage Distribution',
tokenUsageTrend: 'Token Usage Trend',
userUsageTrend: 'User Usage Trend (Top 12)',
model: 'Model',
+ group: 'Group',
+ noGroup: 'No Group',
requests: 'Requests',
tokens: 'Tokens',
actual: 'Actual',
@@ -839,6 +850,181 @@ export default {
failedToLoad: 'Failed to load dashboard statistics'
},
+ dataManagement: {
+ title: 'Data Management',
+ description: 'Manage data management agent status, object storage settings, and backup jobs in one place',
+ agent: {
+ title: 'Data Management Agent Status',
+ description: 'The system probes a fixed Unix socket and enables data management only when reachable.',
+ enabled: 'Data management agent is ready. Data management operations are available.',
+ disabled: 'Data management agent is unavailable. Only diagnostic information is available now.',
+ socketPath: 'Socket Path',
+ version: 'Version',
+ status: 'Status',
+ uptime: 'Uptime',
+ reasonLabel: 'Unavailable Reason',
+ reason: {
+ DATA_MANAGEMENT_AGENT_SOCKET_MISSING: 'Data management socket file is missing',
+ DATA_MANAGEMENT_AGENT_UNAVAILABLE: 'Data management agent is unreachable',
+ BACKUP_AGENT_SOCKET_MISSING: 'Backup socket file is missing',
+ BACKUP_AGENT_UNAVAILABLE: 'Backup agent is unreachable',
+ UNKNOWN: 'Unknown reason'
+ }
+ },
+ sections: {
+ config: {
+ title: 'Backup Configuration',
+ description: 'Configure backup source, retention policy, and S3 settings.'
+ },
+ s3: {
+ title: 'S3 Object Storage',
+ description: 'Configure and test uploads of backup artifacts to a standard S3-compatible storage.'
+ },
+ backup: {
+ title: 'Backup Operations',
+ description: 'Trigger PostgreSQL, Redis, and full backup jobs.'
+ },
+ history: {
+ title: 'Backup History',
+ description: 'Review backup job status, errors, and artifact metadata.'
+ }
+ },
+ form: {
+ sourceMode: 'Source Mode',
+ backupRoot: 'Backup Root',
+ activePostgresProfile: 'Active PostgreSQL Profile',
+ activeRedisProfile: 'Active Redis Profile',
+ activeS3Profile: 'Active S3 Profile',
+ retentionDays: 'Retention Days',
+ keepLast: 'Keep Last Jobs',
+ uploadToS3: 'Upload to S3',
+ useActivePostgresProfile: 'Use Active PostgreSQL Profile',
+ useActiveRedisProfile: 'Use Active Redis Profile',
+ useActiveS3Profile: 'Use Active Profile',
+ idempotencyKey: 'Idempotency Key (Optional)',
+ secretConfigured: 'Configured already, leave empty to keep unchanged',
+ source: {
+ profileID: 'Profile ID (Unique)',
+ profileName: 'Profile Name',
+ setActive: 'Set as active after creation'
+ },
+ postgres: {
+ title: 'PostgreSQL',
+ host: 'Host',
+ port: 'Port',
+ user: 'User',
+ password: 'Password',
+ database: 'Database',
+ sslMode: 'SSL Mode',
+ containerName: 'Container Name (docker_exec mode)'
+ },
+ redis: {
+ title: 'Redis',
+ addr: 'Address (host:port)',
+ username: 'Username',
+ password: 'Password',
+ db: 'Database Index',
+ containerName: 'Container Name (docker_exec mode)'
+ },
+ s3: {
+ enabled: 'Enable S3 Upload',
+ profileID: 'Profile ID (Unique)',
+ profileName: 'Profile Name',
+ endpoint: 'Endpoint (Optional)',
+ region: 'Region',
+ bucket: 'Bucket',
+ accessKeyID: 'Access Key ID',
+ secretAccessKey: 'Secret Access Key',
+ prefix: 'Object Prefix',
+ forcePathStyle: 'Force Path Style',
+ useSSL: 'Use SSL',
+ setActive: 'Set as active after creation'
+ }
+ },
+ sourceProfiles: {
+ createTitle: 'Create Source Profile',
+ editTitle: 'Edit Source Profile',
+ empty: 'No source profiles yet, create one first',
+ deleteConfirm: 'Delete source profile {profileID}?',
+ columns: {
+ profile: 'Profile',
+ active: 'Active',
+ connection: 'Connection',
+ database: 'Database',
+ updatedAt: 'Updated At',
+ actions: 'Actions'
+ }
+ },
+ s3Profiles: {
+ createTitle: 'Create S3 Profile',
+ editTitle: 'Edit S3 Profile',
+ empty: 'No S3 profiles yet, create one first',
+ editHint: 'Click "Edit" to modify profile details in the right drawer.',
+ deleteConfirm: 'Delete S3 profile {profileID}?',
+ columns: {
+ profile: 'Profile',
+ active: 'Active',
+ storage: 'Storage',
+ updatedAt: 'Updated At',
+ actions: 'Actions'
+ }
+ },
+ history: {
+ total: '{count} jobs',
+ empty: 'No backup jobs yet',
+ columns: {
+ jobID: 'Job ID',
+ type: 'Type',
+ status: 'Status',
+ triggeredBy: 'Triggered By',
+ pgProfile: 'PostgreSQL Profile',
+ redisProfile: 'Redis Profile',
+ s3Profile: 'S3 Profile',
+ finishedAt: 'Finished At',
+ artifact: 'Artifact',
+ error: 'Error'
+ },
+ status: {
+ queued: 'Queued',
+ running: 'Running',
+ succeeded: 'Succeeded',
+ failed: 'Failed',
+ partial_succeeded: 'Partial Succeeded'
+ }
+ },
+ actions: {
+ refresh: 'Refresh Status',
+ disabledHint: 'Start datamanagementd first and ensure the socket is reachable.',
+ reloadConfig: 'Reload Config',
+ reloadSourceProfiles: 'Reload Source Profiles',
+ reloadProfiles: 'Reload Profiles',
+ newSourceProfile: 'New Source Profile',
+ saveConfig: 'Save Config',
+ configSaved: 'Configuration saved',
+ testS3: 'Test S3 Connection',
+ s3TestOK: 'S3 connection test succeeded',
+ s3TestFailed: 'S3 connection test failed',
+ newProfile: 'New Profile',
+ saveProfile: 'Save Profile',
+ activateProfile: 'Activate',
+ profileIDRequired: 'Profile ID is required',
+ profileNameRequired: 'Profile name is required',
+ profileSelectRequired: 'Select a profile to edit first',
+ profileCreated: 'S3 profile created',
+ profileSaved: 'S3 profile saved',
+ profileActivated: 'S3 profile activated',
+ profileDeleted: 'S3 profile deleted',
+ sourceProfileCreated: 'Source profile created',
+ sourceProfileSaved: 'Source profile saved',
+ sourceProfileActivated: 'Source profile activated',
+ sourceProfileDeleted: 'Source profile deleted',
+ createBackup: 'Create Backup Job',
+ jobCreated: 'Backup job created: {jobID} ({status})',
+ refreshJobs: 'Refresh Jobs',
+ loadMore: 'Load More'
+ }
+ },
+
// Users
users: {
title: 'User Management',
@@ -897,6 +1083,9 @@ export default {
noApiKeys: 'This user has no API keys',
group: 'Group',
none: 'None',
+ groupChangedSuccess: 'Group updated successfully',
+ groupChangedWithGrant: 'Group updated. User auto-granted access to "{group}"',
+ groupChangeFailed: 'Failed to update group',
noUsersYet: 'No users yet',
createFirstUser: 'Create your first user to get started.',
userCreated: 'User created successfully',
@@ -912,6 +1101,8 @@ export default {
failedToLoadApiKeys: 'Failed to load user API keys',
emailRequired: 'Please enter email',
concurrencyMin: 'Concurrency must be at least 1',
+ soraStorageQuota: 'Sora Storage Quota',
+ soraStorageQuotaHint: 'In GB, 0 means use group or system default quota',
amountRequired: 'Please enter a valid amount',
insufficientBalance: 'Insufficient balance',
deleteConfirm: "Are you sure you want to delete '{email}'? This action cannot be undone.",
@@ -1133,7 +1324,7 @@ export default {
},
imagePricing: {
title: 'Image Generation Pricing',
- description: 'Configure pricing for gemini-3-pro-image model. Leave empty to use default prices.'
+ description: 'Configure pricing for image generation models. Leave empty to use default prices.'
},
soraPricing: {
title: 'Sora Per-Request Pricing',
@@ -1141,7 +1332,9 @@ export default {
image360: 'Image 360px ($)',
image540: 'Image 540px ($)',
video: 'Video (standard) ($)',
- videoHd: 'Video (Pro-HD) ($)'
+ videoHd: 'Video (Pro-HD) ($)',
+ storageQuota: 'Storage Quota',
+ storageQuotaHint: 'In GB, set the Sora storage quota for users in this group. 0 means use system default'
},
claudeCode: {
title: 'Claude Code Client Restriction',
@@ -1378,6 +1571,10 @@ export default {
codeAssist: 'Code Assist',
antigravityOauth: 'Antigravity OAuth',
antigravityApikey: 'Connect via Base URL + API Key',
+ soraApiKey: 'API Key / Upstream',
+ soraApiKeyHint: 'Connect to another Sub2API or compatible API',
+ soraBaseUrlRequired: 'Sora API Key account requires a Base URL',
+ soraBaseUrlInvalidScheme: 'Base URL must start with http:// or https://',
upstream: 'Upstream',
upstreamDesc: 'Connect via Base URL + API Key'
},
@@ -1426,7 +1623,19 @@ export default {
sessions: {
full: 'Active sessions full, new sessions must wait (idle timeout: {idle} min)',
normal: 'Active sessions normal (idle timeout: {idle} min)'
- }
+ },
+ rpm: {
+ full: 'RPM limit reached',
+ warning: 'RPM approaching limit',
+ normal: 'RPM normal',
+ tieredNormal: 'RPM limit (Tiered) - Normal',
+ tieredWarning: 'RPM limit (Tiered) - Approaching limit',
+ tieredStickyOnly: 'RPM limit (Tiered) - Sticky only | Buffer: {buffer}',
+ tieredBlocked: 'RPM limit (Tiered) - Blocked | Buffer: {buffer}',
+ stickyExemptNormal: 'RPM limit (Sticky Exempt) - Normal',
+ stickyExemptWarning: 'RPM limit (Sticky Exempt) - Approaching limit',
+ stickyExemptOver: 'RPM limit (Sticky Exempt) - Over limit, sticky only'
+ },
},
tempUnschedulable: {
title: 'Temp Unschedulable',
@@ -1505,7 +1714,8 @@ export default {
partialSuccess: 'Partially updated: {success} succeeded, {failed} failed',
failed: 'Bulk update failed',
noSelection: 'Please select accounts to edit',
- noFieldsSelected: 'Select at least one field to update'
+ noFieldsSelected: 'Select at least one field to update',
+ mixedPlatformWarning: 'Selected accounts span multiple platforms ({platforms}). Model mapping presets shown are combined — ensure mappings are appropriate for each platform.'
},
bulkDeleteTitle: 'Bulk Delete Accounts',
bulkDeleteConfirm: 'Delete the selected {count} account(s)? This action cannot be undone.',
@@ -1542,6 +1752,24 @@ export default {
oauthPassthrough: 'Auto passthrough (auth only)',
oauthPassthroughDesc:
'When enabled, this OpenAI account uses automatic passthrough: the gateway forwards request/response as-is and only swaps auth, while keeping billing/concurrency/audit and necessary safety filtering.',
+ responsesWebsocketsV2: 'Responses WebSocket v2',
+ responsesWebsocketsV2Desc:
+ 'Disabled by default. Enable to allow responses_websockets_v2 capability (still gated by global and account-type switches).',
+ wsMode: 'WS mode',
+ wsModeDesc: 'Only applies to the current OpenAI account type.',
+ wsModeOff: 'Off (off)',
+ wsModeShared: 'Shared (shared)',
+ wsModeDedicated: 'Dedicated (dedicated)',
+ wsModeConcurrencyHint:
+ 'When WS mode is enabled, account concurrency becomes the WS connection pool limit for this account.',
+ oauthResponsesWebsocketsV2: 'OAuth WebSocket Mode',
+ oauthResponsesWebsocketsV2Desc:
+ 'Only applies to OpenAI OAuth. This account can use OpenAI WebSocket Mode only when enabled.',
+ apiKeyResponsesWebsocketsV2: 'API Key WebSocket Mode',
+ apiKeyResponsesWebsocketsV2Desc:
+ 'Only applies to OpenAI API Key. This account can use OpenAI WebSocket Mode only when enabled.',
+ responsesWebsocketsV2PassthroughHint:
+ 'Automatic passthrough is currently enabled: it only affects HTTP passthrough and does not disable WS mode.',
codexCLIOnly: 'Codex official clients only',
codexCLIOnlyDesc:
'Only applies to OpenAI OAuth. When enabled, only Codex official client families are allowed; when disabled, the gateway bypasses this restriction and keeps existing behavior.',
@@ -1622,6 +1850,22 @@ export default {
idleTimeoutPlaceholder: '5',
idleTimeoutHint: 'Sessions will be released after idle timeout'
},
+ rpmLimit: {
+ label: 'RPM Limit',
+ hint: 'Limit requests per minute to protect upstream accounts',
+ baseRpm: 'Base RPM',
+ baseRpmPlaceholder: '15',
+ baseRpmHint: 'Max requests per minute, 0 or empty means no limit',
+ strategy: 'RPM Strategy',
+ strategyTiered: 'Tiered Model',
+ strategyStickyExempt: 'Sticky Exempt',
+ strategyTieredHint: 'Green → Yellow → Sticky only → Blocked, progressive throttling',
+ strategyStickyExemptHint: 'Only sticky sessions allowed when over limit',
+ strategyHint: 'Tiered: gradually restrict when exceeded; Sticky Exempt: existing sessions unrestricted',
+ stickyBuffer: 'Sticky Buffer',
+ stickyBufferPlaceholder: 'Default: 20% of base RPM',
+ stickyBufferHint: 'Extra requests allowed for sticky sessions after exceeding base RPM. Leave empty to use default (20% of base RPM, min 1)'
+ },
tlsFingerprint: {
label: 'TLS Fingerprint Simulation',
hint: 'Simulate Node.js/Claude Code client TLS fingerprint'
@@ -1751,6 +1995,15 @@ export default {
sessionTokenAuth: 'Manual ST Input',
sessionTokenDesc: 'Enter your existing Sora Session Token(s). Supports batch input (one per line). The system will automatically validate and create accounts.',
sessionTokenPlaceholder: 'Paste your Sora Session Token...\nSupports multiple, one per line',
+ sessionTokenRawLabel: 'Raw Input',
+ sessionTokenRawPlaceholder: 'Paste /api/auth/session raw payload or Session Token...',
+ sessionTokenRawHint: 'You can paste full JSON. The system will auto-parse ST and AT.',
+ openSessionUrl: 'Open Fetch URL',
+ copySessionUrl: 'Copy URL',
+ sessionUrlHint: 'This URL usually returns AT. If sessionToken is absent, copy __Secure-next-auth.session-token from browser cookies as ST.',
+ parsedSessionTokensLabel: 'Parsed ST',
+ parsedSessionTokensEmpty: 'No ST parsed. Please check your input.',
+ parsedAccessTokensLabel: 'Parsed AT',
validating: 'Validating...',
validateAndCreate: 'Validate & Create Account',
pleaseEnterRefreshToken: 'Please enter Refresh Token',
@@ -2001,6 +2254,7 @@ export default {
selectTestModel: 'Select Test Model',
testModel: 'Test model',
testPrompt: 'Prompt: "hi"',
+ soraUpstreamBaseUrlHint: 'Upstream Sora service URL (another Sub2API instance or compatible API)',
soraTestHint: 'Sora test runs connectivity and capability checks (/backend/me, subscription, Sora2 invite and remaining quota).',
soraTestTarget: 'Target: Sora account capability',
soraTestMode: 'Mode: Connectivity + Capability checks',
@@ -2046,7 +2300,7 @@ export default {
geminiFlashDaily: 'Flash',
gemini3Pro: 'G3P',
gemini3Flash: 'G3F',
- gemini3Image: 'G3I',
+ gemini3Image: 'GImage',
claude: 'Claude'
},
tier: {
@@ -2091,6 +2345,8 @@ export default {
dataExportConfirm: 'Confirm Export',
dataExported: 'Data exported successfully',
dataExportFailed: 'Failed to export data',
+ copyProxyUrl: 'Copy Proxy URL',
+ urlCopied: 'Proxy URL copied',
searchProxies: 'Search proxies...',
allProtocols: 'All Protocols',
allStatus: 'All Status',
@@ -2104,6 +2360,7 @@ export default {
name: 'Name',
protocol: 'Protocol',
address: 'Address',
+ auth: 'Auth',
location: 'Location',
status: 'Status',
accounts: 'Accounts',
@@ -3298,7 +3555,23 @@ export default {
defaultBalance: 'Default Balance',
defaultBalanceHint: 'Initial balance for new users',
defaultConcurrency: 'Default Concurrency',
- defaultConcurrencyHint: 'Maximum concurrent requests for new users'
+ defaultConcurrencyHint: 'Maximum concurrent requests for new users',
+ defaultSubscriptions: 'Default Subscriptions',
+ defaultSubscriptionsHint: 'Auto-assign these subscriptions when a new user is created or registered',
+ addDefaultSubscription: 'Add Default Subscription',
+ defaultSubscriptionsEmpty: 'No default subscriptions configured.',
+ defaultSubscriptionsDuplicate:
+ 'Duplicate subscription group: {groupId}. Each group can only appear once.',
+ subscriptionGroup: 'Subscription Group',
+ subscriptionValidityDays: 'Validity (days)'
+ },
+ claudeCode: {
+ title: 'Claude Code Settings',
+ description: 'Control Claude Code client access requirements',
+ minVersion: 'Minimum Version',
+ minVersionPlaceholder: 'e.g. 2.1.63',
+ minVersionHint:
+ 'Reject Claude Code clients below this version (semver format). Leave empty to disable version check.'
},
site: {
title: 'Site Settings',
@@ -3334,15 +3607,23 @@ export default {
hideCcsImportButtonHint: 'When enabled, the "Import to CCS" button will be hidden on the API Keys page'
},
purchase: {
- title: 'Purchase Page',
- description: 'Show a "Purchase Subscription" entry in the sidebar and open the configured URL in an iframe',
- enabled: 'Show Purchase Entry',
+ title: 'Recharge / Subscription Page',
+ description: 'Show a "Recharge / Subscription" entry in the sidebar and open the configured URL in an iframe',
+ enabled: 'Show Recharge / Subscription Entry',
enabledHint: 'Only shown in standard mode (not simple mode)',
- url: 'Purchase URL',
+ url: 'Recharge / Subscription URL',
urlPlaceholder: 'https://example.com/purchase',
urlHint: 'Must be an absolute http(s) URL',
iframeWarning:
- '⚠️ iframe note: Some websites block embedding via X-Frame-Options or CSP (frame-ancestors). If the page is blank, provide an "Open in new tab" alternative.'
+ '⚠️ iframe note: Some websites block embedding via X-Frame-Options or CSP (frame-ancestors). If the page is blank, provide an "Open in new tab" alternative.',
+ integrationDoc: 'Payment Integration Docs',
+ integrationDocHint: 'Covers endpoint specs, idempotency semantics, and code samples'
+ },
+ soraClient: {
+ title: 'Sora Client',
+ description: 'Control whether to show the Sora client entry in the sidebar',
+ enabled: 'Enable Sora Client',
+ enabledHint: 'When enabled, the Sora entry will be shown in the sidebar for users to access Sora features'
},
smtp: {
title: 'SMTP Settings',
@@ -3415,6 +3696,60 @@ export default {
securityWarning: 'Warning: This key provides full admin access. Keep it secure.',
usage: 'Usage: Add to request header - x-api-key: '
},
+ soraS3: {
+ title: 'Sora S3 Storage',
+ description: 'Manage multiple Sora S3 endpoints and switch the active profile',
+ newProfile: 'New Profile',
+ reloadProfiles: 'Reload Profiles',
+ empty: 'No Sora S3 profiles yet, create one first',
+ createTitle: 'Create Sora S3 Profile',
+ editTitle: 'Edit Sora S3 Profile',
+ profileID: 'Profile ID',
+ profileName: 'Profile Name',
+ setActive: 'Set as active after creation',
+ saveProfile: 'Save Profile',
+ activateProfile: 'Activate',
+ profileCreated: 'Sora S3 profile created',
+ profileSaved: 'Sora S3 profile saved',
+ profileDeleted: 'Sora S3 profile deleted',
+ profileActivated: 'Sora S3 active profile switched',
+ profileIDRequired: 'Profile ID is required',
+ profileNameRequired: 'Profile name is required',
+ profileSelectRequired: 'Please select a profile first',
+ endpointRequired: 'S3 endpoint is required when enabled',
+ bucketRequired: 'Bucket is required when enabled',
+ accessKeyRequired: 'Access Key ID is required when enabled',
+ deleteConfirm: 'Delete Sora S3 profile {profileID}?',
+ columns: {
+ profile: 'Profile',
+ active: 'Active',
+ endpoint: 'Endpoint',
+ bucket: 'Bucket',
+ quota: 'Default Quota',
+ updatedAt: 'Updated At',
+ actions: 'Actions'
+ },
+ enabled: 'Enable S3 Storage',
+ enabledHint: 'When enabled, Sora generated media files will be automatically uploaded to S3 storage',
+ endpoint: 'S3 Endpoint',
+ region: 'Region',
+ bucket: 'Bucket',
+ prefix: 'Object Prefix',
+ accessKeyId: 'Access Key ID',
+ secretAccessKey: 'Secret Access Key',
+ secretConfigured: '(Configured, leave blank to keep)',
+ cdnUrl: 'CDN URL',
+ cdnUrlHint: 'Optional. When configured, files are accessed via CDN URL instead of presigned URLs',
+ forcePathStyle: 'Force Path Style',
+ defaultQuota: 'Default Storage Quota',
+ defaultQuotaHint: 'Default quota when not specified at user or group level. 0 means unlimited',
+ testConnection: 'Test Connection',
+ testing: 'Testing...',
+ testSuccess: 'S3 connection test successful',
+ testFailed: 'S3 connection test failed',
+ saved: 'Sora S3 settings saved successfully',
+ saveFailed: 'Failed to save Sora S3 settings'
+ },
streamTimeout: {
title: 'Stream Timeout Handling',
description: 'Configure account handling strategy when upstream response times out',
@@ -3566,16 +3901,16 @@ export default {
retry: 'Retry'
},
- // Purchase Subscription Page
+ // Recharge / Subscription Page
purchase: {
- title: 'Purchase Subscription',
- description: 'Purchase a subscription via the embedded page',
+ title: 'Recharge / Subscription',
+ description: 'Recharge balance or purchase subscription via the embedded page',
openInNewTab: 'Open in new tab',
notEnabledTitle: 'Feature not enabled',
- notEnabledDesc: 'The administrator has not enabled the purchase page. Please contact admin.',
- notConfiguredTitle: 'Purchase URL not configured',
+ notEnabledDesc: 'The administrator has not enabled the recharge/subscription entry. Please contact admin.',
+ notConfiguredTitle: 'Recharge / Subscription URL not configured',
notConfiguredDesc:
- 'The administrator enabled the entry but has not configured a purchase URL. Please contact admin.'
+ 'The administrator enabled the entry but has not configured a recharge/subscription URL. Please contact admin.'
},
// Announcements Page
@@ -3773,5 +4108,93 @@ export default {
description: 'Click to confirm and create your API key.
⚠️ Important:- Copy the key (sk-xxx) immediately after creation
- Key is only shown once, need to regenerate if lost
🚀 How to Use:
Configure the key in any OpenAI-compatible client (like ChatBox, OpenCat, etc.) and start using!
👉 Click "Create" button
'
}
}
+ },
+
+ // Sora Studio
+ sora: {
+ title: 'Sora Studio',
+ description: 'Generate videos and images with Sora AI',
+ notEnabled: 'Feature Not Available',
+ notEnabledDesc: 'The Sora Studio feature has not been enabled by the administrator. Please contact your admin.',
+ tabGenerate: 'Generate',
+ tabLibrary: 'Library',
+ noActiveGenerations: 'No active generations',
+ startGenerating: 'Enter a prompt below to start creating',
+ storage: 'Storage',
+ promptPlaceholder: 'Describe what you want to create...',
+ generate: 'Generate',
+ generating: 'Generating...',
+ selectModel: 'Select Model',
+ statusPending: 'Pending',
+ statusGenerating: 'Generating',
+ statusCompleted: 'Completed',
+ statusFailed: 'Failed',
+ statusCancelled: 'Cancelled',
+ cancel: 'Cancel',
+ delete: 'Delete',
+ save: 'Save to Cloud',
+ saved: 'Saved',
+ retry: 'Retry',
+ download: 'Download',
+ justNow: 'Just now',
+ minutesAgo: '{n} min ago',
+ hoursAgo: '{n} hr ago',
+ noSavedWorks: 'No saved works',
+ saveWorksHint: 'Save your completed generations to the library',
+ filterAll: 'All',
+ filterVideo: 'Video',
+ filterImage: 'Image',
+ confirmDelete: 'Are you sure you want to delete this work?',
+ loading: 'Loading...',
+ loadMore: 'Load More',
+ noStorageWarningTitle: 'No Storage Configured',
+ noStorageWarningDesc: 'Generated content is only available via temporary upstream links that expire in ~15 minutes. Consider configuring S3 storage.',
+ mediaTypeVideo: 'Video',
+ mediaTypeImage: 'Image',
+ notificationCompleted: 'Generation Complete',
+ notificationFailed: 'Generation Failed',
+ notificationCompletedBody: 'Your {model} task has completed',
+ notificationFailedBody: 'Your {model} task has failed',
+ upstreamExpiresSoon: 'Expiring soon',
+ upstreamExpired: 'Link expired',
+ upstreamCountdown: '{time} remaining',
+ previewTitle: 'Preview',
+ closePreview: 'Close',
+ beforeUnloadWarning: 'You have unsaved generated content. Are you sure you want to leave?',
+ downloadTitle: 'Download Generated Content',
+ downloadExpirationWarning: 'This link expires in approximately 15 minutes. Please download and save promptly.',
+ downloadNow: 'Download Now',
+ referenceImage: 'Reference Image',
+ removeImage: 'Remove',
+ imageTooLarge: 'Image size cannot exceed 20MB',
+ // Sora dark theme additions
+ welcomeTitle: 'Turn your imagination into video',
+ welcomeSubtitle: 'Enter a description and Sora will create realistic videos or images for you. Try the examples below to get started.',
+ queueTasks: 'tasks',
+ queueWaiting: 'Queued',
+ waiting: 'Waiting',
+ waited: 'Waited',
+ errorCategory: 'Content Policy Violation',
+ savedToCloud: 'Saved to Cloud',
+ downloadLocal: 'Download',
+ canDownload: 'to download',
+ regenrate: 'Regenerate',
+ creatorPlaceholder: 'Describe the video or image you want to create...',
+ videoModels: 'Video Models',
+ imageModels: 'Image Models',
+ noStorageConfigured: 'No Storage',
+ selectCredential: 'Select Credential',
+ apiKeys: 'API Keys',
+ subscriptions: 'Subscriptions',
+ subscription: 'Subscription',
+ noCredentialHint: 'Please create an API Key or contact admin for subscription',
+ uploadReference: 'Upload reference image',
+ generatingCount: 'Generating {current}/{max}',
+ noStorageToastMessage: 'Cloud storage is not configured. Please use "Download" to save files after generation, otherwise they will be lost.',
+ galleryCount: '{count} works',
+ galleryEmptyTitle: 'No works yet',
+ galleryEmptyDesc: 'Your creations will be displayed here. Go to the generate page to start your first creation.',
+ startCreating: 'Start Creating',
+ yesterday: 'Yesterday'
}
}
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 2720a50f..53818d1a 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -270,6 +270,7 @@ export default {
redeemCodes: '兑换码',
ops: '运维监控',
promoCodes: '优惠码',
+ dataManagement: '数据管理',
settings: '系统设置',
myAccount: '我的账户',
lightMode: '浅色模式',
@@ -279,8 +280,9 @@ export default {
logout: '退出登录',
github: 'GitHub',
mySubscriptions: '我的订阅',
- buySubscription: '购买订阅',
- docs: '文档'
+ buySubscription: '充值/订阅',
+ docs: '文档',
+ sora: 'Sora 创作'
},
// Auth
@@ -407,9 +409,12 @@ export default {
day: '按天',
hour: '按小时',
modelDistribution: '模型分布',
+ groupDistribution: '分组使用分布',
tokenUsageTrend: 'Token 使用趋势',
noDataAvailable: '暂无数据',
model: '模型',
+ group: '分组',
+ noGroup: '无分组',
requests: '请求',
tokens: 'Token',
actual: '实际',
@@ -501,6 +506,7 @@ export default {
claudeCode: 'Claude Code',
geminiCli: 'Gemini CLI',
codexCli: 'Codex CLI',
+ codexCliWs: 'Codex CLI (WebSocket)',
opencode: 'OpenCode'
},
antigravity: {
@@ -618,8 +624,10 @@ export default {
firstToken: '首 Token',
duration: '耗时',
time: '时间',
+ ws: 'WS',
stream: '流式',
sync: '同步',
+ unknown: '未知',
in: '输入',
out: '输出',
cacheRead: '读取',
@@ -841,9 +849,12 @@ export default {
day: '按天',
hour: '按小时',
modelDistribution: '模型分布',
+ groupDistribution: '分组使用分布',
tokenUsageTrend: 'Token 使用趋势',
noDataAvailable: '暂无数据',
model: '模型',
+ group: '分组',
+ noGroup: '无分组',
requests: '请求',
tokens: 'Token',
cache: '缓存',
@@ -862,6 +873,181 @@ export default {
failedToLoad: '加载仪表盘数据失败'
},
+ dataManagement: {
+ title: '数据管理',
+ description: '统一管理数据管理代理状态、对象存储配置和备份任务',
+ agent: {
+ title: '数据管理代理状态',
+ description: '系统会自动探测固定 Unix Socket,仅在可连通时启用数据管理功能。',
+ enabled: '数据管理代理已就绪,可继续进行数据管理操作。',
+ disabled: '数据管理代理不可用,当前仅可查看诊断信息。',
+ socketPath: 'Socket 路径',
+ version: '版本',
+ status: '状态',
+ uptime: '运行时长',
+ reasonLabel: '不可用原因',
+ reason: {
+ DATA_MANAGEMENT_AGENT_SOCKET_MISSING: '未检测到数据管理 Socket 文件',
+ DATA_MANAGEMENT_AGENT_UNAVAILABLE: '数据管理代理不可连通',
+ BACKUP_AGENT_SOCKET_MISSING: '未检测到备份 Socket 文件',
+ BACKUP_AGENT_UNAVAILABLE: '备份代理不可连通',
+ UNKNOWN: '未知原因'
+ }
+ },
+ sections: {
+ config: {
+ title: '备份配置',
+ description: '配置备份源、保留策略与 S3 存储参数。'
+ },
+ s3: {
+ title: 'S3 对象存储',
+ description: '配置并测试备份产物上传到标准 S3 对象存储。'
+ },
+ backup: {
+ title: '备份操作',
+ description: '触发 PostgreSQL、Redis 与全量备份任务。'
+ },
+ history: {
+ title: '备份历史',
+ description: '查看备份任务执行状态、错误与产物信息。'
+ }
+ },
+ form: {
+ sourceMode: '源模式',
+ backupRoot: '备份根目录',
+ activePostgresProfile: '当前激活 PostgreSQL 配置',
+ activeRedisProfile: '当前激活 Redis 配置',
+ activeS3Profile: '当前激活 S3 账号',
+ retentionDays: '保留天数',
+ keepLast: '至少保留最近任务数',
+ uploadToS3: '上传到 S3',
+ useActivePostgresProfile: '使用当前激活 PostgreSQL 配置',
+ useActiveRedisProfile: '使用当前激活 Redis 配置',
+ useActiveS3Profile: '使用当前激活账号',
+ idempotencyKey: '幂等键(可选)',
+ secretConfigured: '已配置,留空不变',
+ source: {
+ profileID: '配置 ID(唯一)',
+ profileName: '配置名称',
+ setActive: '创建后立即设为激活配置'
+ },
+ postgres: {
+ title: 'PostgreSQL',
+ host: '主机',
+ port: '端口',
+ user: '用户名',
+ password: '密码',
+ database: '数据库',
+ sslMode: 'SSL 模式',
+ containerName: '容器名(docker_exec 模式)'
+ },
+ redis: {
+ title: 'Redis',
+ addr: '地址(host:port)',
+ username: '用户名',
+ password: '密码',
+ db: '数据库编号',
+ containerName: '容器名(docker_exec 模式)'
+ },
+ s3: {
+ enabled: '启用 S3 上传',
+ profileID: '账号 ID(唯一)',
+ profileName: '账号名称',
+ endpoint: 'Endpoint(可选)',
+ region: 'Region',
+ bucket: 'Bucket',
+ accessKeyID: 'Access Key ID',
+ secretAccessKey: 'Secret Access Key',
+ prefix: '对象前缀',
+ forcePathStyle: '强制 path-style',
+ useSSL: '使用 SSL',
+ setActive: '创建后立即设为激活账号'
+ }
+ },
+ sourceProfiles: {
+ createTitle: '创建数据源配置',
+ editTitle: '编辑数据源配置',
+ empty: '暂无配置,请先创建',
+ deleteConfirm: '确定删除配置 {profileID} 吗?',
+ columns: {
+ profile: '配置',
+ active: '激活状态',
+ connection: '连接信息',
+ database: '数据库',
+ updatedAt: '更新时间',
+ actions: '操作'
+ }
+ },
+ s3Profiles: {
+ createTitle: '创建 S3 账号',
+ editTitle: '编辑 S3 账号',
+ empty: '暂无 S3 账号,请先创建',
+ editHint: '点击“编辑”将在右侧抽屉中修改账号信息。',
+ deleteConfirm: '确定删除 S3 账号 {profileID} 吗?',
+ columns: {
+ profile: '账号',
+ active: '激活状态',
+ storage: '存储配置',
+ updatedAt: '更新时间',
+ actions: '操作'
+ }
+ },
+ history: {
+ total: '共 {count} 条',
+ empty: '暂无备份任务',
+ columns: {
+ jobID: '任务 ID',
+ type: '类型',
+ status: '状态',
+ triggeredBy: '触发人',
+ pgProfile: 'PostgreSQL 配置',
+ redisProfile: 'Redis 配置',
+ s3Profile: 'S3 账号',
+ finishedAt: '完成时间',
+ artifact: '产物',
+ error: '错误'
+ },
+ status: {
+ queued: '排队中',
+ running: '执行中',
+ succeeded: '成功',
+ failed: '失败',
+ partial_succeeded: '部分成功'
+ }
+ },
+ actions: {
+ refresh: '刷新状态',
+ disabledHint: '请先启动 datamanagementd 并确认 Socket 可连通。',
+ reloadConfig: '加载配置',
+ reloadSourceProfiles: '刷新数据源配置',
+ reloadProfiles: '刷新账号列表',
+ newSourceProfile: '新建数据源配置',
+ saveConfig: '保存配置',
+ configSaved: '配置保存成功',
+ testS3: '测试 S3 连接',
+ s3TestOK: 'S3 连接测试成功',
+ s3TestFailed: 'S3 连接测试失败',
+ newProfile: '新建账号',
+ saveProfile: '保存账号',
+ activateProfile: '设为激活',
+ profileIDRequired: '请输入账号 ID',
+ profileNameRequired: '请输入账号名称',
+ profileSelectRequired: '请先选择要编辑的账号',
+ profileCreated: 'S3 账号创建成功',
+ profileSaved: 'S3 账号保存成功',
+ profileActivated: 'S3 账号已切换为激活',
+ profileDeleted: 'S3 账号删除成功',
+ sourceProfileCreated: '数据源配置创建成功',
+ sourceProfileSaved: '数据源配置保存成功',
+ sourceProfileActivated: '数据源配置已切换为激活',
+ sourceProfileDeleted: '数据源配置删除成功',
+ createBackup: '创建备份任务',
+ jobCreated: '备份任务已创建:{jobID}({status})',
+ refreshJobs: '刷新任务',
+ loadMore: '加载更多'
+ }
+ },
+
// Users Management
users: {
title: '用户管理',
@@ -925,6 +1111,9 @@ export default {
noApiKeys: '此用户暂无 API 密钥',
group: '分组',
none: '无',
+ groupChangedSuccess: '分组修改成功',
+ groupChangedWithGrant: '分组修改成功,已自动为用户添加「{group}」分组权限',
+ groupChangeFailed: '分组修改失败',
noUsersYet: '暂无用户',
createFirstUser: '创建您的第一个用户以开始使用系统',
userCreated: '用户创建成功',
@@ -978,6 +1167,8 @@ export default {
failedToAdjust: '调整失败',
emailRequired: '请输入邮箱',
concurrencyMin: '并发数不能小于1',
+ soraStorageQuota: 'Sora 存储配额',
+ soraStorageQuotaHint: '单位 GB,0 表示使用分组或系统默认配额',
amountRequired: '请输入有效金额',
insufficientBalance: '余额不足',
setAllowedGroups: '设置允许分组',
@@ -1220,7 +1411,7 @@ export default {
},
imagePricing: {
title: '图片生成计费',
- description: '配置 gemini-3-pro-image 模型的图片生成价格,留空则使用默认价格'
+ description: '配置图片生成模型的图片生成价格,留空则使用默认价格'
},
soraPricing: {
title: 'Sora 按次计费',
@@ -1228,7 +1419,9 @@ export default {
image360: '图片 360px ($)',
image540: '图片 540px ($)',
video: '视频(标准)($)',
- videoHd: '视频(Pro-HD)($)'
+ videoHd: '视频(Pro-HD)($)',
+ storageQuota: '存储配额',
+ storageQuotaHint: '单位 GB,设置该分组用户的 Sora 存储配额上限,0 表示使用系统默认'
},
claudeCode: {
title: 'Claude Code 客户端限制',
@@ -1481,7 +1674,19 @@ export default {
sessions: {
full: '活跃会话已满,新会话需等待(空闲超时:{idle}分钟)',
normal: '活跃会话正常(空闲超时:{idle}分钟)'
- }
+ },
+ rpm: {
+ full: '已达 RPM 上限',
+ warning: 'RPM 接近上限',
+ normal: 'RPM 正常',
+ tieredNormal: 'RPM 限制 (三区模型) - 正常',
+ tieredWarning: 'RPM 限制 (三区模型) - 接近阈值',
+ tieredStickyOnly: 'RPM 限制 (三区模型) - 仅粘性会话 | 缓冲区: {buffer}',
+ tieredBlocked: 'RPM 限制 (三区模型) - 已阻塞 | 缓冲区: {buffer}',
+ stickyExemptNormal: 'RPM 限制 (粘性豁免) - 正常',
+ stickyExemptWarning: 'RPM 限制 (粘性豁免) - 接近阈值',
+ stickyExemptOver: 'RPM 限制 (粘性豁免) - 超限,仅粘性会话'
+ },
},
clearRateLimit: '清除速率限制',
testConnection: '测试连接',
@@ -1512,6 +1717,10 @@ export default {
codeAssist: 'Code Assist',
antigravityOauth: 'Antigravity OAuth',
antigravityApikey: '通过 Base URL + API Key 连接',
+ soraApiKey: 'API Key / 上游透传',
+ soraApiKeyHint: '连接另一个 Sub2API 或兼容 API',
+ soraBaseUrlRequired: 'Sora apikey 账号必须设置上游地址(Base URL)',
+ soraBaseUrlInvalidScheme: 'Base URL 必须以 http:// 或 https:// 开头',
upstream: '对接上游',
upstreamDesc: '通过 Base URL + API Key 连接上游',
api_key: 'API Key',
@@ -1582,7 +1791,7 @@ export default {
geminiFlashDaily: 'Flash',
gemini3Pro: 'G3P',
gemini3Flash: 'G3F',
- gemini3Image: 'G3I',
+ gemini3Image: 'GImage',
claude: 'Claude'
},
tier: {
@@ -1652,7 +1861,8 @@ export default {
partialSuccess: '部分更新成功:成功 {success} 个,失败 {failed} 个',
failed: '批量更新失败',
noSelection: '请选择要编辑的账号',
- noFieldsSelected: '请至少选择一个要更新的字段'
+ noFieldsSelected: '请至少选择一个要更新的字段',
+ mixedPlatformWarning: '所选账号跨越多个平台({platforms})。显示的模型映射预设为合并结果——请确保映射对每个平台都适用。'
},
bulkDeleteTitle: '批量删除账号',
bulkDeleteConfirm: '确定要删除选中的 {count} 个账号吗?此操作无法撤销。',
@@ -1691,6 +1901,22 @@ export default {
oauthPassthrough: '自动透传(仅替换认证)',
oauthPassthroughDesc:
'开启后,该 OpenAI 账号将自动透传请求与响应,仅替换认证并保留计费/并发/审计及必要安全过滤;如遇兼容性问题可随时关闭回滚。',
+ responsesWebsocketsV2: 'Responses WebSocket v2',
+ responsesWebsocketsV2Desc:
+ '默认关闭。开启后可启用 responses_websockets_v2 协议能力(受网关全局开关与账号类型开关约束)。',
+ wsMode: 'WS mode',
+ wsModeDesc: '仅对当前 OpenAI 账号类型生效。',
+ wsModeOff: '关闭(off)',
+ wsModeShared: '共享(shared)',
+ wsModeDedicated: '独享(dedicated)',
+ wsModeConcurrencyHint: '启用 WS mode 后,该账号并发数将作为该账号 WS 连接池上限。',
+ oauthResponsesWebsocketsV2: 'OAuth WebSocket Mode',
+ oauthResponsesWebsocketsV2Desc:
+ '仅对 OpenAI OAuth 生效。开启后该账号才允许使用 OpenAI WebSocket Mode 协议。',
+ apiKeyResponsesWebsocketsV2: 'API Key WebSocket Mode',
+ apiKeyResponsesWebsocketsV2Desc:
+ '仅对 OpenAI API Key 生效。开启后该账号才允许使用 OpenAI WebSocket Mode 协议。',
+ responsesWebsocketsV2PassthroughHint: '当前已开启自动透传:仅影响 HTTP 透传链路,不影响 WS mode。',
codexCLIOnly: '仅允许 Codex 官方客户端',
codexCLIOnlyDesc: '仅对 OpenAI OAuth 生效。开启后仅允许 Codex 官方客户端家族访问;关闭后完全绕过并保持原逻辑。',
modelRestrictionDisabledByPassthrough: '已开启自动透传:模型白名单/映射不会生效。',
@@ -1767,6 +1993,22 @@ export default {
idleTimeoutPlaceholder: '5',
idleTimeoutHint: '会话空闲超时后自动释放'
},
+ rpmLimit: {
+ label: 'RPM 限制',
+ hint: '限制每分钟请求数量,保护上游账号',
+ baseRpm: '基础 RPM',
+ baseRpmPlaceholder: '15',
+ baseRpmHint: '每分钟最大请求数,0 或留空表示不限制',
+ strategy: 'RPM 策略',
+ strategyTiered: '三区模型',
+ strategyStickyExempt: '粘性豁免',
+ strategyTieredHint: '绿区→黄区→仅粘性→阻塞,逐步限流',
+ strategyStickyExemptHint: '超限后仅允许粘性会话',
+ strategyHint: '三区模型: 超限后逐步限制; 粘性豁免: 已有会话不受限',
+ stickyBuffer: '粘性缓冲区',
+ stickyBufferPlaceholder: '默认: base RPM 的 20%',
+ stickyBufferHint: '超过 base RPM 后,粘性会话额外允许的请求数。为空则使用默认值(base RPM 的 20%,最小为 1)'
+ },
tlsFingerprint: {
label: 'TLS 指纹模拟',
hint: '模拟 Node.js/Claude Code 客户端的 TLS 指纹'
@@ -1890,6 +2132,15 @@ export default {
sessionTokenAuth: '手动输入 ST',
sessionTokenDesc: '输入您已有的 Sora Session Token,支持批量输入(每行一个),系统将自动验证并创建账号。',
sessionTokenPlaceholder: '粘贴您的 Sora Session Token...\n支持多个,每行一个',
+ sessionTokenRawLabel: '原始字符串',
+ sessionTokenRawPlaceholder: '粘贴 /api/auth/session 原始数据或 Session Token...',
+ sessionTokenRawHint: '支持粘贴完整 JSON,系统会自动解析 ST 和 AT。',
+ openSessionUrl: '打开获取链接',
+ copySessionUrl: '复制链接',
+ sessionUrlHint: '该链接通常可获取 AT。若返回中无 sessionToken,请从浏览器 Cookie 复制 __Secure-next-auth.session-token 作为 ST。',
+ parsedSessionTokensLabel: '解析出的 ST',
+ parsedSessionTokensEmpty: '未解析到 ST,请检查输入内容',
+ parsedAccessTokensLabel: '解析出的 AT',
validating: '验证中...',
validateAndCreate: '验证并创建账号',
pleaseEnterRefreshToken: '请输入 Refresh Token',
@@ -2133,6 +2384,7 @@ export default {
selectTestModel: '选择测试模型',
testModel: '测试模型',
testPrompt: '提示词:"hi"',
+ soraUpstreamBaseUrlHint: '上游 Sora 服务地址(另一个 Sub2API 实例或兼容 API)',
soraTestHint: 'Sora 测试将执行连通性与能力检测(/backend/me、订阅信息、Sora2 邀请码与剩余额度)。',
soraTestTarget: '检测目标:Sora 账号能力',
soraTestMode: '模式:连通性 + 能力探测',
@@ -2207,6 +2459,7 @@ export default {
name: '名称',
protocol: '协议',
address: '地址',
+ auth: '认证',
location: '地理位置',
status: '状态',
accounts: '账号数',
@@ -2234,6 +2487,8 @@ export default {
allStatuses: '全部状态'
},
// Additional keys used in ProxiesView
+ copyProxyUrl: '复制代理 URL',
+ urlCopied: '代理 URL 已复制',
allProtocols: '全部协议',
allStatus: '全部状态',
searchProxies: '搜索代理...',
@@ -3470,7 +3725,21 @@ export default {
defaultBalance: '默认余额',
defaultBalanceHint: '新用户的初始余额',
defaultConcurrency: '默认并发数',
- defaultConcurrencyHint: '新用户的最大并发请求数'
+ defaultConcurrencyHint: '新用户的最大并发请求数',
+ defaultSubscriptions: '默认订阅列表',
+ defaultSubscriptionsHint: '新用户创建或注册时自动分配这些订阅',
+ addDefaultSubscription: '添加默认订阅',
+ defaultSubscriptionsEmpty: '未配置默认订阅。新用户不会自动获得订阅套餐。',
+ defaultSubscriptionsDuplicate: '默认订阅存在重复分组:{groupId}。每个分组只能出现一次。',
+ subscriptionGroup: '订阅分组',
+ subscriptionValidityDays: '有效期(天)'
+ },
+ claudeCode: {
+ title: 'Claude Code 设置',
+ description: '控制 Claude Code 客户端访问要求',
+ minVersion: '最低版本号',
+ minVersionPlaceholder: '例如 2.1.63',
+ minVersionHint: '拒绝低于此版本的 Claude Code 客户端请求(semver 格式)。留空则不检查版本。'
},
site: {
title: '站点设置',
@@ -3508,15 +3777,23 @@ export default {
hideCcsImportButtonHint: '启用后将在 API Keys 页面隐藏"导入 CCS"按钮'
},
purchase: {
- title: '购买订阅页面',
- description: '在侧边栏展示“购买订阅”入口,并在页面内通过 iframe 打开指定链接',
- enabled: '显示购买订阅入口',
+ title: '充值/订阅页面',
+ description: '在侧边栏展示“充值/订阅”入口,并在页面内通过 iframe 打开指定链接',
+ enabled: '显示充值/订阅入口',
enabledHint: '仅在标准模式(非简单模式)下展示',
- url: '购买页面 URL',
+ url: '充值/订阅页面 URL',
urlPlaceholder: 'https://example.com/purchase',
urlHint: '必须是完整的 http(s) 链接',
iframeWarning:
- '⚠️ iframe 提示:部分网站会通过 X-Frame-Options 或 CSP(frame-ancestors)禁止被 iframe 嵌入,出现空白时可引导用户使用“新窗口打开”。'
+ '⚠️ iframe 提示:部分网站会通过 X-Frame-Options 或 CSP(frame-ancestors)禁止被 iframe 嵌入,出现空白时可引导用户使用”新窗口打开”。',
+ integrationDoc: '支付集成文档',
+ integrationDocHint: '包含接口说明、幂等语义及示例代码'
+ },
+ soraClient: {
+ title: 'Sora 客户端',
+ description: '控制是否在侧边栏展示 Sora 客户端入口',
+ enabled: '启用 Sora 客户端',
+ enabledHint: '开启后,侧边栏将显示 Sora 入口,用户可访问 Sora 功能'
},
smtp: {
title: 'SMTP 设置',
@@ -3588,6 +3865,60 @@ export default {
securityWarning: '警告:此密钥拥有完整的管理员权限,请妥善保管。',
usage: '使用方法:在请求头中添加 x-api-key: '
},
+ soraS3: {
+ title: 'Sora S3 存储配置',
+ description: '以多配置列表方式管理 Sora S3 端点,并可切换生效配置',
+ newProfile: '新建配置',
+ reloadProfiles: '刷新列表',
+ empty: '暂无 Sora S3 配置,请先创建',
+ createTitle: '新建 Sora S3 配置',
+ editTitle: '编辑 Sora S3 配置',
+ profileID: '配置 ID',
+ profileName: '配置名称',
+ setActive: '创建后设为生效',
+ saveProfile: '保存配置',
+ activateProfile: '设为生效',
+ profileCreated: 'Sora S3 配置创建成功',
+ profileSaved: 'Sora S3 配置保存成功',
+ profileDeleted: 'Sora S3 配置删除成功',
+ profileActivated: 'Sora S3 生效配置已切换',
+ profileIDRequired: '请填写配置 ID',
+ profileNameRequired: '请填写配置名称',
+ profileSelectRequired: '请先选择配置',
+ endpointRequired: '启用时必须填写 S3 端点',
+ bucketRequired: '启用时必须填写存储桶',
+ accessKeyRequired: '启用时必须填写 Access Key ID',
+ deleteConfirm: '确定删除 Sora S3 配置 {profileID} 吗?',
+ columns: {
+ profile: '配置',
+ active: '生效状态',
+ endpoint: '端点',
+ bucket: '存储桶',
+ quota: '默认配额',
+ updatedAt: '更新时间',
+ actions: '操作'
+ },
+ enabled: '启用 S3 存储',
+ enabledHint: '启用后,Sora 生成的媒体文件将自动上传到 S3 存储',
+ endpoint: 'S3 端点',
+ region: '区域',
+ bucket: '存储桶',
+ prefix: '对象前缀',
+ accessKeyId: 'Access Key ID',
+ secretAccessKey: 'Secret Access Key',
+ secretConfigured: '(已配置,留空保持不变)',
+ cdnUrl: 'CDN URL',
+ cdnUrlHint: '可选,配置后使用 CDN URL 访问文件,否则使用预签名 URL',
+ forcePathStyle: '强制路径风格(Path Style)',
+ defaultQuota: '默认存储配额',
+ defaultQuotaHint: '未在用户或分组级别指定配额时的默认值,0 表示无限制',
+ testConnection: '测试连接',
+ testing: '测试中...',
+ testSuccess: 'S3 连接测试成功',
+ testFailed: 'S3 连接测试失败',
+ saved: 'Sora S3 设置保存成功',
+ saveFailed: '保存 Sora S3 设置失败'
+ },
streamTimeout: {
title: '流超时处理',
description: '配置上游响应超时时的账户处理策略,避免问题账户持续被选中',
@@ -3739,15 +4070,15 @@ export default {
retry: '重试'
},
- // Purchase Subscription Page
+ // Recharge / Subscription Page
purchase: {
- title: '购买订阅',
- description: '通过内嵌页面完成订阅购买',
+ title: '充值/订阅',
+ description: '通过内嵌页面完成充值/订阅',
openInNewTab: '新窗口打开',
notEnabledTitle: '该功能未开启',
- notEnabledDesc: '管理员暂未开启购买订阅入口,请联系管理员。',
- notConfiguredTitle: '购买链接未配置',
- notConfiguredDesc: '管理员已开启入口,但尚未配置购买订阅链接,请联系管理员。'
+ notEnabledDesc: '管理员暂未开启充值/订阅入口,请联系管理员。',
+ notConfiguredTitle: '充值/订阅链接未配置',
+ notConfiguredDesc: '管理员已开启入口,但尚未配置充值/订阅链接,请联系管理员。'
},
// Announcements Page
@@ -3971,5 +4302,93 @@ export default {
'点击确认创建您的 API 密钥。
⚠️ 重要:- 创建后请立即复制密钥(sk-xxx)
- 密钥只显示一次,丢失需重新生成
🚀 如何使用:
将密钥配置到支持 OpenAI 接口的任何客户端(如 ChatBox、OpenCat 等),即可开始使用!
👉 点击"创建"按钮
'
}
}
+ },
+
+ // Sora 创作
+ sora: {
+ title: 'Sora 创作',
+ description: '使用 Sora AI 生成视频与图片',
+ notEnabled: '功能未开放',
+ notEnabledDesc: '管理员尚未启用 Sora 创作功能,请联系管理员开通。',
+ tabGenerate: '生成',
+ tabLibrary: '作品库',
+ noActiveGenerations: '暂无生成任务',
+ startGenerating: '在下方输入提示词,开始创作',
+ storage: '存储',
+ promptPlaceholder: '描述你想创作的内容...',
+ generate: '生成',
+ generating: '生成中...',
+ selectModel: '选择模型',
+ statusPending: '等待中',
+ statusGenerating: '生成中',
+ statusCompleted: '已完成',
+ statusFailed: '失败',
+ statusCancelled: '已取消',
+ cancel: '取消',
+ delete: '删除',
+ save: '保存到云端',
+ saved: '已保存',
+ retry: '重试',
+ download: '下载',
+ justNow: '刚刚',
+ minutesAgo: '{n} 分钟前',
+ hoursAgo: '{n} 小时前',
+ noSavedWorks: '暂无保存的作品',
+ saveWorksHint: '生成完成后,将作品保存到作品库',
+ filterAll: '全部',
+ filterVideo: '视频',
+ filterImage: '图片',
+ confirmDelete: '确定删除此作品?',
+ loading: '加载中...',
+ loadMore: '加载更多',
+ noStorageWarningTitle: '未配置存储',
+ noStorageWarningDesc: '生成的内容仅通过上游临时链接提供,约 15 分钟后过期。建议管理员配置 S3 存储。',
+ mediaTypeVideo: '视频',
+ mediaTypeImage: '图片',
+ notificationCompleted: '生成完成',
+ notificationFailed: '生成失败',
+ notificationCompletedBody: '您的 {model} 任务已完成',
+ notificationFailedBody: '您的 {model} 任务失败了',
+ upstreamExpiresSoon: '即将过期',
+ upstreamExpired: '链接已过期',
+ upstreamCountdown: '剩余 {time}',
+ previewTitle: '作品预览',
+ closePreview: '关闭',
+ beforeUnloadWarning: '您有未保存的生成内容,确定要离开吗?',
+ downloadTitle: '下载生成内容',
+ downloadExpirationWarning: '此链接约 15 分钟后过期,请尽快下载保存。',
+ downloadNow: '立即下载',
+ referenceImage: '参考图',
+ removeImage: '移除',
+ imageTooLarge: '图片大小不能超过 20MB',
+ // Sora 暗色主题新增
+ welcomeTitle: '将你的想象力变成视频',
+ welcomeSubtitle: '输入一段描述,Sora 将为你创作逼真的视频或图片。尝试以下示例开始创作。',
+ queueTasks: '个任务',
+ queueWaiting: '队列中等待',
+ waiting: '等待中',
+ waited: '已等待',
+ errorCategory: '内容策略限制',
+ savedToCloud: '已保存到云端',
+ downloadLocal: '本地下载',
+ canDownload: '可下载',
+ regenrate: '重新生成',
+ creatorPlaceholder: '描述你想要生成的视频或图片...',
+ videoModels: '视频模型',
+ imageModels: '图片模型',
+ noStorageConfigured: '存储未配置',
+ selectCredential: '选择凭证',
+ apiKeys: 'API 密钥',
+ subscriptions: '订阅',
+ subscription: '订阅',
+ noCredentialHint: '请先创建 API Key 或联系管理员分配订阅',
+ uploadReference: '上传参考图片',
+ generatingCount: '正在生成 {current}/{max}',
+ noStorageToastMessage: '管理员未开通云存储,生成完成后请使用"本地下载"保存文件,否则将会丢失。',
+ galleryCount: '共 {count} 个作品',
+ galleryEmptyTitle: '还没有任何作品',
+ galleryEmptyDesc: '你的创作成果将会展示在这里。前往生成页,开始你的第一次创作吧。',
+ startCreating: '开始创作',
+ yesterday: '昨天'
}
}
diff --git a/frontend/src/main.ts b/frontend/src/main.ts
index 6b809ec2..51f6f0cc 100644
--- a/frontend/src/main.ts
+++ b/frontend/src/main.ts
@@ -6,7 +6,18 @@ import i18n, { initI18n } from './i18n'
import { useAppStore } from '@/stores/app'
import './style.css'
+function initThemeClass() {
+ const savedTheme = localStorage.getItem('theme')
+ const shouldUseDark =
+ savedTheme === 'dark' ||
+ (!savedTheme && window.matchMedia('(prefers-color-scheme: dark)').matches)
+ document.documentElement.classList.toggle('dark', shouldUseDark)
+}
+
async function bootstrap() {
+ // Apply theme class globally before app mount to keep all routes consistent.
+ initThemeClass()
+
const app = createApp(App)
const pinia = createPinia()
app.use(pinia)
diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts
index c0bab4ec..125a5013 100644
--- a/frontend/src/router/index.ts
+++ b/frontend/src/router/index.ts
@@ -41,7 +41,8 @@ const routes: RouteRecordRaw[] = [
component: () => import('@/views/auth/LoginView.vue'),
meta: {
requiresAuth: false,
- title: 'Login'
+ title: 'Login',
+ titleKey: 'common.login'
}
},
{
@@ -50,7 +51,8 @@ const routes: RouteRecordRaw[] = [
component: () => import('@/views/auth/RegisterView.vue'),
meta: {
requiresAuth: false,
- title: 'Register'
+ title: 'Register',
+ titleKey: 'auth.createAccount'
}
},
{
@@ -86,7 +88,8 @@ const routes: RouteRecordRaw[] = [
component: () => import('@/views/auth/ForgotPasswordView.vue'),
meta: {
requiresAuth: false,
- title: 'Forgot Password'
+ title: 'Forgot Password',
+ titleKey: 'auth.forgotPasswordTitle'
}
},
{
@@ -188,6 +191,18 @@ const routes: RouteRecordRaw[] = [
descriptionKey: 'purchase.description'
}
},
+ {
+ path: '/sora',
+ name: 'Sora',
+ component: () => import('@/views/user/SoraView.vue'),
+ meta: {
+ requiresAuth: true,
+ requiresAdmin: false,
+ title: 'Sora',
+ titleKey: 'sora.title',
+ descriptionKey: 'sora.description'
+ }
+ },
// ==================== Admin Routes ====================
{
@@ -314,6 +329,18 @@ const routes: RouteRecordRaw[] = [
descriptionKey: 'admin.promo.description'
}
},
+ {
+ path: '/admin/data-management',
+ name: 'AdminDataManagement',
+ component: () => import('@/views/admin/DataManagementView.vue'),
+ meta: {
+ requiresAuth: true,
+ requiresAdmin: true,
+ title: 'Data Management',
+ titleKey: 'admin.dataManagement.title',
+ descriptionKey: 'admin.dataManagement.description'
+ }
+ },
{
path: '/admin/settings',
name: 'AdminSettings',
@@ -390,7 +417,7 @@ router.beforeEach((to, _from, next) => {
// Set page title
const appStore = useAppStore()
- document.title = resolveDocumentTitle(to.meta.title, appStore.siteName)
+ document.title = resolveDocumentTitle(to.meta.title, appStore.siteName, to.meta.titleKey as string)
// Check if route requires authentication
const requiresAuth = to.meta.requiresAuth !== false // Default to true
diff --git a/frontend/src/router/title.ts b/frontend/src/router/title.ts
index ed25ed1f..d6cdd2e9 100644
--- a/frontend/src/router/title.ts
+++ b/frontend/src/router/title.ts
@@ -1,9 +1,19 @@
+import { i18n } from '@/i18n'
+
/**
* 统一生成页面标题,避免多处写入 document.title 产生覆盖冲突。
+ * 优先使用 titleKey 通过 i18n 翻译,fallback 到静态 routeTitle。
*/
-export function resolveDocumentTitle(routeTitle: unknown, siteName?: string): string {
+export function resolveDocumentTitle(routeTitle: unknown, siteName?: string, titleKey?: string): string {
const normalizedSiteName = typeof siteName === 'string' && siteName.trim() ? siteName.trim() : 'TianShuAPI'
+ if (typeof titleKey === 'string' && titleKey.trim()) {
+ const translated = i18n.global.t(titleKey)
+ if (translated && translated !== titleKey) {
+ return `${translated} - ${normalizedSiteName}`
+ }
+ }
+
if (typeof routeTitle === 'string' && routeTitle.trim()) {
return `${routeTitle.trim()} - ${normalizedSiteName}`
}
diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts
index c48b1cd8..7e61befa 100644
--- a/frontend/src/stores/app.ts
+++ b/frontend/src/stores/app.ts
@@ -328,6 +328,7 @@ export const useAppStore = defineStore('app', () => {
purchase_subscription_enabled: false,
purchase_subscription_url: '',
linuxdo_oauth_enabled: false,
+ sora_client_enabled: false,
version: siteVersion.value
}
}
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index 03378b63..faf69f79 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -45,6 +45,9 @@ export interface AdminUser extends User {
group_rates?: Record
// 当前并发数(仅管理员列表接口返回)
current_concurrency?: number
+ // Sora 存储配额(字节)
+ sora_storage_quota_bytes: number
+ sora_storage_used_bytes: number
}
export interface LoginRequest {
@@ -91,6 +94,7 @@ export interface PublicSettings {
purchase_subscription_enabled: boolean
purchase_subscription_url: string
linuxdo_oauth_enabled: boolean
+ sora_client_enabled: boolean
version: string
}
@@ -363,6 +367,8 @@ export interface Group {
sora_image_price_540: number | null
sora_video_price_per_request: number | null
sora_video_price_per_request_hd: number | null
+ // Sora 存储配额(字节)
+ sora_storage_quota_bytes: number
// Claude Code 客户端限制
claude_code_only: boolean
fallback_group_id: number | null
@@ -445,6 +451,7 @@ export interface CreateGroupRequest {
sora_image_price_540?: number | null
sora_video_price_per_request?: number | null
sora_video_price_per_request_hd?: number | null
+ sora_storage_quota_bytes?: number
claude_code_only?: boolean
fallback_group_id?: number | null
fallback_group_id_on_invalid_request?: number | null
@@ -472,6 +479,7 @@ export interface UpdateGroupRequest {
sora_image_price_540?: number | null
sora_video_price_per_request?: number | null
sora_video_price_per_request_hd?: number | null
+ sora_storage_quota_bytes?: number
claude_code_only?: boolean
fallback_group_id?: number | null
fallback_group_id_on_invalid_request?: number | null
@@ -653,6 +661,11 @@ export interface Account {
max_sessions?: number | null
session_idle_timeout_minutes?: number | null
+ // RPM 限制(仅 Anthropic OAuth/SetupToken 账号有效)
+ base_rpm?: number | null
+ rpm_strategy?: string | null
+ rpm_sticky_buffer?: number | null
+
// TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效)
enable_tls_fingerprint?: boolean | null
@@ -667,6 +680,7 @@ export interface Account {
// 运行时状态(仅当启用对应限制时返回)
current_window_cost?: number | null // 当前窗口费用
active_sessions?: number | null // 当前活跃会话数
+ current_rpm?: number | null // 当前分钟 RPM 计数
}
// Account Usage types
@@ -859,6 +873,7 @@ export interface AdminDataImportResult {
// ==================== Usage & Redeem Types ====================
export type RedeemCodeType = 'balance' | 'concurrency' | 'subscription' | 'invitation'
+export type UsageRequestType = 'unknown' | 'sync' | 'stream' | 'ws_v2'
export interface UsageLog {
id: number
@@ -888,7 +903,9 @@ export interface UsageLog {
rate_multiplier: number
billing_type: number
+ request_type?: UsageRequestType
stream: boolean
+ openai_ws_mode?: boolean
duration_ms: number
first_token_ms: number | null
@@ -934,6 +951,7 @@ export interface UsageCleanupFilters {
account_id?: number
group_id?: number
model?: string | null
+ request_type?: UsageRequestType | null
stream?: boolean | null
billing_type?: number | null
}
@@ -1068,6 +1086,15 @@ export interface ModelStat {
actual_cost: number // 实际扣除
}
+export interface GroupStat {
+ group_id: number
+ group_name: string
+ requests: number
+ total_tokens: number
+ cost: number // 标准计费
+ actual_cost: number // 实际扣除
+}
+
export interface UserUsageTrendPoint {
date: string
user_id: number
@@ -1178,6 +1205,7 @@ export interface UsageQueryParams {
account_id?: number
group_id?: number
model?: string
+ request_type?: UsageRequestType
stream?: boolean
billing_type?: number | null
start_date?: string
diff --git a/frontend/src/utils/__tests__/openaiWsMode.spec.ts b/frontend/src/utils/__tests__/openaiWsMode.spec.ts
new file mode 100644
index 00000000..39f21aef
--- /dev/null
+++ b/frontend/src/utils/__tests__/openaiWsMode.spec.ts
@@ -0,0 +1,55 @@
+import { describe, expect, it } from 'vitest'
+import {
+ OPENAI_WS_MODE_DEDICATED,
+ OPENAI_WS_MODE_OFF,
+ OPENAI_WS_MODE_SHARED,
+ isOpenAIWSModeEnabled,
+ normalizeOpenAIWSMode,
+ openAIWSModeFromEnabled,
+ resolveOpenAIWSModeFromExtra
+} from '@/utils/openaiWsMode'
+
+describe('openaiWsMode utils', () => {
+ it('normalizes mode values', () => {
+ expect(normalizeOpenAIWSMode('off')).toBe(OPENAI_WS_MODE_OFF)
+ expect(normalizeOpenAIWSMode(' Shared ')).toBe(OPENAI_WS_MODE_SHARED)
+ expect(normalizeOpenAIWSMode('DEDICATED')).toBe(OPENAI_WS_MODE_DEDICATED)
+ expect(normalizeOpenAIWSMode('invalid')).toBeNull()
+ })
+
+ it('maps legacy enabled flag to mode', () => {
+ expect(openAIWSModeFromEnabled(true)).toBe(OPENAI_WS_MODE_SHARED)
+ expect(openAIWSModeFromEnabled(false)).toBe(OPENAI_WS_MODE_OFF)
+ expect(openAIWSModeFromEnabled('true')).toBeNull()
+ })
+
+ it('resolves by mode key first, then enabled, then fallback enabled keys', () => {
+ const extra = {
+ openai_oauth_responses_websockets_v2_mode: 'dedicated',
+ openai_oauth_responses_websockets_v2_enabled: false,
+ responses_websockets_v2_enabled: false
+ }
+ const mode = resolveOpenAIWSModeFromExtra(extra, {
+ modeKey: 'openai_oauth_responses_websockets_v2_mode',
+ enabledKey: 'openai_oauth_responses_websockets_v2_enabled',
+ fallbackEnabledKeys: ['responses_websockets_v2_enabled', 'openai_ws_enabled']
+ })
+ expect(mode).toBe(OPENAI_WS_MODE_DEDICATED)
+ })
+
+ it('falls back to default when nothing is present', () => {
+ const mode = resolveOpenAIWSModeFromExtra({}, {
+ modeKey: 'openai_apikey_responses_websockets_v2_mode',
+ enabledKey: 'openai_apikey_responses_websockets_v2_enabled',
+ fallbackEnabledKeys: ['responses_websockets_v2_enabled', 'openai_ws_enabled'],
+ defaultMode: OPENAI_WS_MODE_OFF
+ })
+ expect(mode).toBe(OPENAI_WS_MODE_OFF)
+ })
+
+ it('treats off as disabled and shared/dedicated as enabled', () => {
+ expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_OFF)).toBe(false)
+ expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_SHARED)).toBe(true)
+ expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_DEDICATED)).toBe(true)
+ })
+})
diff --git a/frontend/src/utils/__tests__/soraTokenParser.spec.ts b/frontend/src/utils/__tests__/soraTokenParser.spec.ts
new file mode 100644
index 00000000..816e5319
--- /dev/null
+++ b/frontend/src/utils/__tests__/soraTokenParser.spec.ts
@@ -0,0 +1,90 @@
+import { describe, expect, it } from 'vitest'
+import { parseSoraRawTokens } from '@/utils/soraTokenParser'
+
+describe('parseSoraRawTokens', () => {
+ it('parses sessionToken and accessToken from JSON payload', () => {
+ const payload = JSON.stringify({
+ user: { id: 'u1' },
+ accessToken: 'at-json-1',
+ sessionToken: 'st-json-1'
+ })
+
+ const result = parseSoraRawTokens(payload)
+
+ expect(result.sessionTokens).toEqual(['st-json-1'])
+ expect(result.accessTokens).toEqual(['at-json-1'])
+ })
+
+ it('supports plain session tokens (one per line)', () => {
+ const result = parseSoraRawTokens('st-1\nst-2')
+
+ expect(result.sessionTokens).toEqual(['st-1', 'st-2'])
+ expect(result.accessTokens).toEqual([])
+ })
+
+ it('supports non-standard object snippets via regex', () => {
+ const raw = "sessionToken: 'st-snippet', access_token: \"at-snippet\""
+ const result = parseSoraRawTokens(raw)
+
+ expect(result.sessionTokens).toEqual(['st-snippet'])
+ expect(result.accessTokens).toEqual(['at-snippet'])
+ })
+
+ it('keeps unique tokens and extracts JWT-like plain line as AT too', () => {
+ const jwt = 'eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxIn0.signature'
+ const raw = `st-dup\nst-dup\n${jwt}\n${JSON.stringify({ sessionToken: 'st-json', accessToken: jwt })}`
+ const result = parseSoraRawTokens(raw)
+
+ expect(result.sessionTokens).toEqual(['st-json', 'st-dup'])
+ expect(result.accessTokens).toEqual([jwt])
+ })
+
+ it('parses session token from Set-Cookie line and strips cookie attributes', () => {
+ const raw =
+ '__Secure-next-auth.session-token.0=st-cookie-part-0; Domain=.chatgpt.com; Path=/; Expires=Thu, 28 May 2026 11:43:36 GMT; HttpOnly; Secure; SameSite=Lax'
+ const result = parseSoraRawTokens(raw)
+
+ expect(result.sessionTokens).toEqual(['st-cookie-part-0'])
+ expect(result.accessTokens).toEqual([])
+ })
+
+ it('merges chunked session-token cookies by numeric suffix order', () => {
+ const raw = [
+ 'Set-Cookie: __Secure-next-auth.session-token.1=part-1; Path=/; HttpOnly',
+ 'Set-Cookie: __Secure-next-auth.session-token.0=part-0; Path=/; HttpOnly'
+ ].join('\n')
+ const result = parseSoraRawTokens(raw)
+
+ expect(result.sessionTokens).toEqual(['part-0part-1'])
+ expect(result.accessTokens).toEqual([])
+ })
+
+ it('prefers latest duplicate chunk values when multiple cookie groups exist', () => {
+ const raw = [
+ 'Set-Cookie: __Secure-next-auth.session-token.0=old-0; Path=/; HttpOnly',
+ 'Set-Cookie: __Secure-next-auth.session-token.1=old-1; Path=/; HttpOnly',
+ 'Set-Cookie: __Secure-next-auth.session-token.0=new-0; Path=/; HttpOnly',
+ 'Set-Cookie: __Secure-next-auth.session-token.1=new-1; Path=/; HttpOnly'
+ ].join('\n')
+ const result = parseSoraRawTokens(raw)
+
+ expect(result.sessionTokens).toEqual(['new-0new-1'])
+ expect(result.accessTokens).toEqual([])
+ })
+
+ it('uses latest complete chunk group and ignores incomplete latest group', () => {
+ const raw = [
+ 'set-cookie',
+ '__Secure-next-auth.session-token.0=ok-0; Domain=.chatgpt.com; Path=/',
+ 'set-cookie',
+ '__Secure-next-auth.session-token.1=ok-1; Domain=.chatgpt.com; Path=/',
+ 'set-cookie',
+ '__Secure-next-auth.session-token.0=partial-0; Domain=.chatgpt.com; Path=/'
+ ].join('\n')
+
+ const result = parseSoraRawTokens(raw)
+
+ expect(result.sessionTokens).toEqual(['ok-0ok-1'])
+ expect(result.accessTokens).toEqual([])
+ })
+})
diff --git a/frontend/src/utils/openaiWsMode.ts b/frontend/src/utils/openaiWsMode.ts
new file mode 100644
index 00000000..b3e9cc00
--- /dev/null
+++ b/frontend/src/utils/openaiWsMode.ts
@@ -0,0 +1,61 @@
+export const OPENAI_WS_MODE_OFF = 'off'
+export const OPENAI_WS_MODE_SHARED = 'shared'
+export const OPENAI_WS_MODE_DEDICATED = 'dedicated'
+
+export type OpenAIWSMode =
+ | typeof OPENAI_WS_MODE_OFF
+ | typeof OPENAI_WS_MODE_SHARED
+ | typeof OPENAI_WS_MODE_DEDICATED
+
+const OPENAI_WS_MODES = new Set([
+ OPENAI_WS_MODE_OFF,
+ OPENAI_WS_MODE_SHARED,
+ OPENAI_WS_MODE_DEDICATED
+])
+
+export interface ResolveOpenAIWSModeOptions {
+ modeKey: string
+ enabledKey: string
+ fallbackEnabledKeys?: string[]
+ defaultMode?: OpenAIWSMode
+}
+
+export const normalizeOpenAIWSMode = (mode: unknown): OpenAIWSMode | null => {
+ if (typeof mode !== 'string') return null
+ const normalized = mode.trim().toLowerCase()
+ if (OPENAI_WS_MODES.has(normalized as OpenAIWSMode)) {
+ return normalized as OpenAIWSMode
+ }
+ return null
+}
+
+export const openAIWSModeFromEnabled = (enabled: unknown): OpenAIWSMode | null => {
+ if (typeof enabled !== 'boolean') return null
+ return enabled ? OPENAI_WS_MODE_SHARED : OPENAI_WS_MODE_OFF
+}
+
+export const isOpenAIWSModeEnabled = (mode: OpenAIWSMode): boolean => {
+ return mode !== OPENAI_WS_MODE_OFF
+}
+
+export const resolveOpenAIWSModeFromExtra = (
+ extra: Record | null | undefined,
+ options: ResolveOpenAIWSModeOptions
+): OpenAIWSMode => {
+ const fallback = options.defaultMode ?? OPENAI_WS_MODE_OFF
+ if (!extra) return fallback
+
+ const mode = normalizeOpenAIWSMode(extra[options.modeKey])
+ if (mode) return mode
+
+ const enabledMode = openAIWSModeFromEnabled(extra[options.enabledKey])
+ if (enabledMode) return enabledMode
+
+ const fallbackKeys = options.fallbackEnabledKeys ?? []
+ for (const key of fallbackKeys) {
+ const modeFromFallbackKey = openAIWSModeFromEnabled(extra[key])
+ if (modeFromFallbackKey) return modeFromFallbackKey
+ }
+
+ return fallback
+}
diff --git a/frontend/src/utils/soraTokenParser.ts b/frontend/src/utils/soraTokenParser.ts
new file mode 100644
index 00000000..87e36649
--- /dev/null
+++ b/frontend/src/utils/soraTokenParser.ts
@@ -0,0 +1,308 @@
+export interface ParsedSoraTokens {
+ sessionTokens: string[]
+ accessTokens: string[]
+}
+
+const sessionKeyNames = new Set(['sessiontoken', 'session_token', 'st'])
+const accessKeyNames = new Set(['accesstoken', 'access_token', 'at'])
+
+const sessionRegexes = [
+ /\bsessionToken\b\s*:\s*["']([^"']+)["']/gi,
+ /\bsession_token\b\s*:\s*["']([^"']+)["']/gi
+]
+
+const accessRegexes = [
+ /\baccessToken\b\s*:\s*["']([^"']+)["']/gi,
+ /\baccess_token\b\s*:\s*["']([^"']+)["']/gi
+]
+
+const sessionCookieRegex =
+ /(?:^|[\n\r;])\s*(?:(?:set-cookie|cookie)\s*:\s*)?__Secure-(?:next-auth|authjs)\.session-token(?:\.(\d+))?=([^;\r\n]+)/gi
+
+interface SessionCookieChunk {
+ index: number
+ value: string
+}
+
+const ignoredPlainLines = new Set([
+ 'set-cookie',
+ 'cookie',
+ 'strict-transport-security',
+ 'vary',
+ 'x-content-type-options',
+ 'x-openai-proxy-wasm'
+])
+
+function sanitizeToken(raw: string): string {
+ return raw.trim().replace(/^["'`]+|["'`,;]+$/g, '')
+}
+
+function addUnique(list: string[], seen: Set, rawValue: string): void {
+ const token = sanitizeToken(rawValue)
+ if (!token || seen.has(token)) {
+ return
+ }
+ seen.add(token)
+ list.push(token)
+}
+
+function isLikelyJWT(token: string): boolean {
+ if (!token.startsWith('eyJ')) {
+ return false
+ }
+ return token.split('.').length === 3
+}
+
+function collectFromObject(
+ value: unknown,
+ sessionTokens: string[],
+ sessionSeen: Set,
+ accessTokens: string[],
+ accessSeen: Set
+): void {
+ if (Array.isArray(value)) {
+ for (const item of value) {
+ collectFromObject(item, sessionTokens, sessionSeen, accessTokens, accessSeen)
+ }
+ return
+ }
+ if (!value || typeof value !== 'object') {
+ return
+ }
+
+ for (const [key, fieldValue] of Object.entries(value as Record)) {
+ if (typeof fieldValue === 'string') {
+ const normalizedKey = key.toLowerCase()
+ if (sessionKeyNames.has(normalizedKey)) {
+ addUnique(sessionTokens, sessionSeen, fieldValue)
+ }
+ if (accessKeyNames.has(normalizedKey)) {
+ addUnique(accessTokens, accessSeen, fieldValue)
+ }
+ continue
+ }
+ collectFromObject(fieldValue, sessionTokens, sessionSeen, accessTokens, accessSeen)
+ }
+}
+
+function collectFromJSONString(
+ raw: string,
+ sessionTokens: string[],
+ sessionSeen: Set,
+ accessTokens: string[],
+ accessSeen: Set
+): void {
+ const trimmed = raw.trim()
+ if (!trimmed) {
+ return
+ }
+
+ const candidates = [trimmed]
+ const firstBrace = trimmed.indexOf('{')
+ const lastBrace = trimmed.lastIndexOf('}')
+ if (firstBrace >= 0 && lastBrace > firstBrace) {
+ candidates.push(trimmed.slice(firstBrace, lastBrace + 1))
+ }
+
+ for (const candidate of candidates) {
+ try {
+ const parsed = JSON.parse(candidate)
+ collectFromObject(parsed, sessionTokens, sessionSeen, accessTokens, accessSeen)
+ return
+ } catch {
+ // ignore and keep trying other candidates
+ }
+ }
+}
+
+function collectByRegex(
+ raw: string,
+ regexes: RegExp[],
+ tokens: string[],
+ seen: Set
+): void {
+ for (const regex of regexes) {
+ regex.lastIndex = 0
+ let match: RegExpExecArray | null
+ match = regex.exec(raw)
+ while (match) {
+ if (match[1]) {
+ addUnique(tokens, seen, match[1])
+ }
+ match = regex.exec(raw)
+ }
+ }
+}
+
+function collectFromSessionCookies(
+ raw: string,
+ sessionTokens: string[],
+ sessionSeen: Set
+): void {
+ const chunkMatches: SessionCookieChunk[] = []
+ const singleValues: string[] = []
+
+ sessionCookieRegex.lastIndex = 0
+ let match: RegExpExecArray | null
+ match = sessionCookieRegex.exec(raw)
+ while (match) {
+ const chunkIndex = match[1]
+ const rawValue = match[2]
+ const value = sanitizeToken(rawValue || '')
+ if (value) {
+ if (chunkIndex !== undefined && chunkIndex !== '') {
+ const idx = Number.parseInt(chunkIndex, 10)
+ if (Number.isInteger(idx) && idx >= 0) {
+ chunkMatches.push({ index: idx, value })
+ }
+ } else {
+ singleValues.push(value)
+ }
+ }
+ match = sessionCookieRegex.exec(raw)
+ }
+
+ const mergedChunkToken = mergeLatestChunkedSessionToken(chunkMatches)
+ if (mergedChunkToken) {
+ addUnique(sessionTokens, sessionSeen, mergedChunkToken)
+ }
+
+ for (const value of singleValues) {
+ addUnique(sessionTokens, sessionSeen, value)
+ }
+}
+
+function mergeChunkSegment(
+ chunks: SessionCookieChunk[],
+ requiredMaxIndex: number,
+ requireComplete: boolean
+): string {
+ if (chunks.length === 0) {
+ return ''
+ }
+
+ const byIndex = new Map()
+ for (const chunk of chunks) {
+ byIndex.set(chunk.index, chunk.value)
+ }
+
+ if (!byIndex.has(0)) {
+ return ''
+ }
+ if (requireComplete) {
+ for (let i = 0; i <= requiredMaxIndex; i++) {
+ if (!byIndex.has(i)) {
+ return ''
+ }
+ }
+ }
+
+ const orderedIndexes = Array.from(byIndex.keys()).sort((a, b) => a - b)
+ return orderedIndexes.map((idx) => byIndex.get(idx) || '').join('')
+}
+
+function mergeLatestChunkedSessionToken(chunks: SessionCookieChunk[]): string {
+ if (chunks.length === 0) {
+ return ''
+ }
+
+ const requiredMaxIndex = chunks.reduce((max, chunk) => Math.max(max, chunk.index), 0)
+
+ const groupStarts: number[] = []
+ chunks.forEach((chunk, idx) => {
+ if (chunk.index === 0) {
+ groupStarts.push(idx)
+ }
+ })
+
+ if (groupStarts.length === 0) {
+ return mergeChunkSegment(chunks, requiredMaxIndex, false)
+ }
+
+ for (let i = groupStarts.length - 1; i >= 0; i--) {
+ const start = groupStarts[i]
+ const end = i + 1 < groupStarts.length ? groupStarts[i + 1] : chunks.length
+ const merged = mergeChunkSegment(chunks.slice(start, end), requiredMaxIndex, true)
+ if (merged) {
+ return merged
+ }
+ }
+
+ return mergeChunkSegment(chunks, requiredMaxIndex, false)
+}
+
+function collectPlainLines(
+ raw: string,
+ sessionTokens: string[],
+ sessionSeen: Set,
+ accessTokens: string[],
+ accessSeen: Set
+): void {
+ const lines = raw
+ .split('\n')
+ .map((line) => line.trim())
+ .filter((line) => line.length > 0)
+
+ for (const line of lines) {
+ const normalized = line.toLowerCase()
+ if (ignoredPlainLines.has(normalized)) {
+ continue
+ }
+ if (/^__secure-(next-auth|authjs)\.session-token(\.\d+)?=/i.test(line)) {
+ continue
+ }
+ if (line.includes(';')) {
+ continue
+ }
+
+ if (/^[a-zA-Z_][a-zA-Z0-9_]*=/.test(line)) {
+ const parts = line.split('=', 2)
+ const key = parts[0]?.trim().toLowerCase()
+ const value = parts[1]?.trim() || ''
+ if (key && sessionKeyNames.has(key)) {
+ addUnique(sessionTokens, sessionSeen, value)
+ continue
+ }
+ if (key && accessKeyNames.has(key)) {
+ addUnique(accessTokens, accessSeen, value)
+ continue
+ }
+ }
+
+ if (line.includes('{') || line.includes('}') || line.includes(':') || /\s/.test(line)) {
+ continue
+ }
+
+ if (isLikelyJWT(line)) {
+ addUnique(accessTokens, accessSeen, line)
+ continue
+ }
+ addUnique(sessionTokens, sessionSeen, line)
+ }
+}
+
+export function parseSoraRawTokens(rawInput: string): ParsedSoraTokens {
+ const raw = rawInput.trim()
+ if (!raw) {
+ return {
+ sessionTokens: [],
+ accessTokens: []
+ }
+ }
+
+ const sessionTokens: string[] = []
+ const accessTokens: string[] = []
+ const sessionSeen = new Set()
+ const accessSeen = new Set()
+
+ collectFromJSONString(raw, sessionTokens, sessionSeen, accessTokens, accessSeen)
+ collectByRegex(raw, sessionRegexes, sessionTokens, sessionSeen)
+ collectByRegex(raw, accessRegexes, accessTokens, accessSeen)
+ collectFromSessionCookies(raw, sessionTokens, sessionSeen)
+ collectPlainLines(raw, sessionTokens, sessionSeen, accessTokens, accessSeen)
+
+ return {
+ sessionTokens,
+ accessTokens
+ }
+}
diff --git a/frontend/src/utils/usageRequestType.ts b/frontend/src/utils/usageRequestType.ts
new file mode 100644
index 00000000..bfdafb07
--- /dev/null
+++ b/frontend/src/utils/usageRequestType.ts
@@ -0,0 +1,33 @@
+import type { UsageRequestType } from '@/types'
+
+export interface UsageRequestTypeLike {
+ request_type?: string | null
+ stream?: boolean | null
+ openai_ws_mode?: boolean | null
+}
+
+const VALID_REQUEST_TYPES = new Set(['unknown', 'sync', 'stream', 'ws_v2'])
+
+export const isUsageRequestType = (value: unknown): value is UsageRequestType => {
+ return typeof value === 'string' && VALID_REQUEST_TYPES.has(value as UsageRequestType)
+}
+
+export const resolveUsageRequestType = (value: UsageRequestTypeLike): UsageRequestType => {
+ if (isUsageRequestType(value.request_type)) {
+ return value.request_type
+ }
+ if (value.openai_ws_mode) {
+ return 'ws_v2'
+ }
+ return value.stream ? 'stream' : 'sync'
+}
+
+export const requestTypeToLegacyStream = (requestType?: UsageRequestType | null): boolean | null | undefined => {
+ if (!requestType || requestType === 'unknown') {
+ return null
+ }
+ if (requestType === 'sync') {
+ return false
+ }
+ return true
+}
diff --git a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue
index 236c6f54..defcd434 100644
--- a/frontend/src/views/admin/AccountsView.vue
+++ b/frontend/src/views/admin/AccountsView.vue
@@ -184,7 +184,11 @@
-
+
@@ -259,7 +263,7 @@
-
+
@@ -273,7 +277,7 @@
+
+
diff --git a/frontend/src/views/admin/GroupsView.vue b/frontend/src/views/admin/GroupsView.vue
index 4d6dccf6..aa0a49a7 100644
--- a/frontend/src/views/admin/GroupsView.vue
+++ b/frontend/src/views/admin/GroupsView.vue
@@ -459,7 +459,7 @@
step="0.001"
min="0"
class="input"
- placeholder="0.134"
+ placeholder="0.201"
/>
@@ -532,6 +532,23 @@
/>
+
+
+
+
+ GB
+
+
+ {{ t('admin.groups.soraPricing.storageQuotaHint') }}
+
+
@@ -1139,7 +1156,7 @@
step="0.001"
min="0"
class="input"
- placeholder="0.134"
+ placeholder="0.201"
/>
@@ -1212,6 +1229,23 @@
/>
+
+
+
+
+ GB
+
+
+ {{ t('admin.groups.soraPricing.storageQuotaHint') }}
+
+
@@ -1881,6 +1915,7 @@ const createForm = reactive({
sora_image_price_540: null as number | null,
sora_video_price_per_request: null as number | null,
sora_video_price_per_request_hd: null as number | null,
+ sora_storage_quota_gb: null as number | null,
// Claude Code 客户端限制(仅 anthropic 平台使用)
claude_code_only: false,
fallback_group_id: null as number | null,
@@ -2121,6 +2156,7 @@ const editForm = reactive({
sora_image_price_540: null as number | null,
sora_video_price_per_request: null as number | null,
sora_video_price_per_request_hd: null as number | null,
+ sora_storage_quota_gb: null as number | null,
// Claude Code 客户端限制(仅 anthropic 平台使用)
claude_code_only: false,
fallback_group_id: null as number | null,
@@ -2220,6 +2256,7 @@ const closeCreateModal = () => {
createForm.sora_image_price_540 = null
createForm.sora_video_price_per_request = null
createForm.sora_video_price_per_request_hd = null
+ createForm.sora_storage_quota_gb = null
createForm.claude_code_only = false
createForm.fallback_group_id = null
createForm.fallback_group_id_on_invalid_request = null
@@ -2237,8 +2274,10 @@ const handleCreateGroup = async () => {
submitting.value = true
try {
// 构建请求数据,包含模型路由配置
+ const { sora_storage_quota_gb: createQuotaGb, ...createRest } = createForm
const requestData = {
- ...createForm,
+ ...createRest,
+ sora_storage_quota_bytes: createQuotaGb ? Math.round(createQuotaGb * 1024 * 1024 * 1024) : 0,
model_routing: convertRoutingRulesToApiFormat(createModelRoutingRules.value)
}
await adminAPI.groups.create(requestData)
@@ -2277,6 +2316,7 @@ const handleEdit = async (group: AdminGroup) => {
editForm.sora_image_price_540 = group.sora_image_price_540
editForm.sora_video_price_per_request = group.sora_video_price_per_request
editForm.sora_video_price_per_request_hd = group.sora_video_price_per_request_hd
+ editForm.sora_storage_quota_gb = group.sora_storage_quota_bytes ? Number((group.sora_storage_quota_bytes / (1024 * 1024 * 1024)).toFixed(2)) : null
editForm.claude_code_only = group.claude_code_only || false
editForm.fallback_group_id = group.fallback_group_id
editForm.fallback_group_id_on_invalid_request = group.fallback_group_id_on_invalid_request
@@ -2310,8 +2350,10 @@ const handleUpdateGroup = async () => {
submitting.value = true
try {
// 转换 fallback_group_id: null -> 0 (后端使用 0 表示清除)
+ const { sora_storage_quota_gb: editQuotaGb, ...editRest } = editForm
const payload = {
- ...editForm,
+ ...editRest,
+ sora_storage_quota_bytes: editQuotaGb ? Math.round(editQuotaGb * 1024 * 1024 * 1024) : 0,
fallback_group_id: editForm.fallback_group_id === null ? 0 : editForm.fallback_group_id,
fallback_group_id_on_invalid_request:
editForm.fallback_group_id_on_invalid_request === null
diff --git a/frontend/src/views/admin/ProxiesView.vue b/frontend/src/views/admin/ProxiesView.vue
index 23d73109..147b3205 100644
--- a/frontend/src/views/admin/ProxiesView.vue
+++ b/frontend/src/views/admin/ProxiesView.vue
@@ -124,7 +124,54 @@
- {{ row.host }}:{{ row.port }}
+
+
{{ row.host }}:{{ row.port }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ row.username }}
+
+ {{ visiblePasswordIds.has(row.id) ? row.password : '••••••' }}
+
+
+
+
+ -
@@ -397,12 +444,21 @@
@@ -581,12 +637,22 @@
@@ -813,15 +879,18 @@ import ImportDataModal from '@/components/admin/proxy/ImportDataModal.vue'
import Select from '@/components/common/Select.vue'
import Icon from '@/components/icons/Icon.vue'
import PlatformTypeBadge from '@/components/common/PlatformTypeBadge.vue'
+import { useClipboard } from '@/composables/useClipboard'
const { t } = useI18n()
const appStore = useAppStore()
+const { copyToClipboard } = useClipboard()
const columns = computed
(() => [
{ key: 'select', label: '', sortable: false },
{ key: 'name', label: t('admin.proxies.columns.name'), sortable: true },
{ key: 'protocol', label: t('admin.proxies.columns.protocol'), sortable: true },
{ key: 'address', label: t('admin.proxies.columns.address'), sortable: false },
+ { key: 'auth', label: t('admin.proxies.columns.auth'), sortable: false },
{ key: 'location', label: t('admin.proxies.columns.location'), sortable: false },
{ key: 'account_count', label: t('admin.proxies.columns.accounts'), sortable: true },
{ key: 'latency', label: t('admin.proxies.columns.latency'), sortable: false },
@@ -858,6 +927,8 @@ const editStatusOptions = computed(() => [
])
const proxies = ref([])
+const visiblePasswordIds = reactive(new Set())
+const copyMenuProxyId = ref(null)
const loading = ref(false)
const searchQuery = ref('')
const filters = reactive({
@@ -872,7 +943,10 @@ const pagination = reactive({
})
const showCreateModal = ref(false)
+const createPasswordVisible = ref(false)
const showEditModal = ref(false)
+const editPasswordVisible = ref(false)
+const editPasswordDirty = ref(false)
const showImportData = ref(false)
const showDeleteDialog = ref(false)
const showBatchDeleteDialog = ref(false)
@@ -1030,6 +1104,7 @@ const closeCreateModal = () => {
createForm.port = 8080
createForm.username = ''
createForm.password = ''
+ createPasswordVisible.value = false
batchInput.value = ''
batchParseResult.total = 0
batchParseResult.valid = 0
@@ -1173,14 +1248,18 @@ const handleEdit = (proxy: Proxy) => {
editForm.host = proxy.host
editForm.port = proxy.port
editForm.username = proxy.username || ''
- editForm.password = ''
+ editForm.password = proxy.password || ''
editForm.status = proxy.status
+ editPasswordVisible.value = false
+ editPasswordDirty.value = false
showEditModal.value = true
}
const closeEditModal = () => {
showEditModal.value = false
editingProxy.value = null
+ editPasswordVisible.value = false
+ editPasswordDirty.value = false
}
const handleUpdateProxy = async () => {
@@ -1209,10 +1288,9 @@ const handleUpdateProxy = async () => {
status: editForm.status
}
- // Only include password if it was changed
- const trimmedPassword = editForm.password.trim()
- if (trimmedPassword) {
- updateData.password = trimmedPassword
+ // Only include password if user actually modified the field
+ if (editPasswordDirty.value) {
+ updateData.password = editForm.password.trim() || null
}
await adminAPI.proxies.update(editingProxy.value.id, updateData)
@@ -1715,12 +1793,60 @@ const closeAccountsModal = () => {
proxyAccounts.value = []
}
+// ── Proxy URL copy ──
+function buildAuthPart(row: any): string {
+ const user = row.username ? encodeURIComponent(row.username) : ''
+ const pass = row.password ? encodeURIComponent(row.password) : ''
+ if (user && pass) return `${user}:${pass}@`
+ if (user) return `${user}@`
+ if (pass) return `:${pass}@`
+ return ''
+}
+
+function buildProxyUrl(row: any): string {
+ return `${row.protocol}://${buildAuthPart(row)}${row.host}:${row.port}`
+}
+
+function getCopyFormats(row: any) {
+ const hasAuth = row.username || row.password
+ const fullUrl = buildProxyUrl(row)
+ const formats = [
+ { label: fullUrl, value: fullUrl },
+ ]
+ if (hasAuth) {
+ const withoutProtocol = fullUrl.replace(/^[^:]+:\/\//, '')
+ formats.push({ label: withoutProtocol, value: withoutProtocol })
+ }
+ formats.push({ label: `${row.host}:${row.port}`, value: `${row.host}:${row.port}` })
+ return formats
+}
+
+function copyProxyUrl(row: any) {
+ copyToClipboard(buildProxyUrl(row), t('admin.proxies.urlCopied'))
+ copyMenuProxyId.value = null
+}
+
+function toggleCopyMenu(id: number) {
+ copyMenuProxyId.value = copyMenuProxyId.value === id ? null : id
+}
+
+function copyFormat(value: string) {
+ copyToClipboard(value, t('admin.proxies.urlCopied'))
+ copyMenuProxyId.value = null
+}
+
+function closeCopyMenu() {
+ copyMenuProxyId.value = null
+}
+
onMounted(() => {
loadProxies()
+ document.addEventListener('click', closeCopyMenu)
})
onUnmounted(() => {
clearTimeout(searchTimeout)
abortController?.abort()
+ document.removeEventListener('click', closeCopyMenu)
})
diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue
index e7c886c6..acbe9cf5 100644
--- a/frontend/src/views/admin/SettingsView.vue
+++ b/frontend/src/views/admin/SettingsView.vue
@@ -579,7 +579,7 @@
{{ t('admin.settings.defaults.description') }}
-
+
+
+
+
+
+
+
+ {{ t('admin.settings.defaults.defaultSubscriptionsHint') }}
+
+
+
+
+
+
+ {{ t('admin.settings.defaults.defaultSubscriptionsEmpty') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.claudeCode.title') }}
+
+
+ {{ t('admin.settings.claudeCode.description') }}
+
+
+
+
+
+
+
+ {{ t('admin.settings.claudeCode.minVersionHint') }}
+
+
@@ -991,6 +1112,51 @@
{{ t('admin.settings.purchase.iframeWarning') }}
+
+
+
+
+
+ {{ t('admin.settings.purchase.integrationDoc') }}
+
+
—
+
+ {{ t('admin.settings.purchase.integrationDocHint') }}
+
+
+
+
+
+
+
+
+
+ {{ t('admin.settings.soraClient.title') }}
+
+
+ {{ t('admin.settings.soraClient.description') }}
+
+
+
+
+
+
+
+ {{ t('admin.settings.soraClient.enabledHint') }}
+
+
+
+
@@ -1083,9 +1249,17 @@
import { ref, reactive, computed, onMounted } from 'vue'
import { useI18n } from 'vue-i18n'
import { adminAPI } from '@/api'
-import type { SystemSettings, UpdateSettingsRequest } from '@/api/admin/settings'
+import type {
+ SystemSettings,
+ UpdateSettingsRequest,
+ DefaultSubscriptionSetting
+} from '@/api/admin/settings'
+import type { AdminGroup } from '@/types'
import AppLayout from '@/components/layout/AppLayout.vue'
import Icon from '@/components/icons/Icon.vue'
+import Select from '@/components/common/Select.vue'
+import GroupBadge from '@/components/common/GroupBadge.vue'
+import GroupOptionItem from '@/components/common/GroupOptionItem.vue'
import Toggle from '@/components/common/Toggle.vue'
import { useClipboard } from '@/composables/useClipboard'
import { useAppStore } from '@/stores'
@@ -1107,6 +1281,7 @@ const adminApiKeyExists = ref(false)
const adminApiKeyMasked = ref('')
const adminApiKeyOperating = ref(false)
const newAdminApiKey = ref('')
+const subscriptionGroups = ref([])
// Stream Timeout 状态
const streamTimeoutLoading = ref(true)
@@ -1119,6 +1294,16 @@ const streamTimeoutForm = reactive({
threshold_window_minutes: 10
})
+interface DefaultSubscriptionGroupOption {
+ value: number
+ label: string
+ description: string | null
+ platform: AdminGroup['platform']
+ subscriptionType: AdminGroup['subscription_type']
+ rate: number
+ [key: string]: unknown
+}
+
type SettingsForm = SystemSettings & {
smtp_password: string
turnstile_secret_key: string
@@ -1135,6 +1320,7 @@ const form = reactive({
totp_encryption_key_configured: false,
default_balance: 0,
default_concurrency: 1,
+ default_subscriptions: [],
site_name: 'TianShuAPI',
site_logo: '',
site_subtitle: 'Subscription to API Conversion Platform',
@@ -1145,6 +1331,7 @@ const form = reactive({
hide_ccs_import_button: false,
purchase_subscription_enabled: false,
purchase_subscription_url: '',
+ sora_client_enabled: false,
smtp_host: '',
smtp_port: 587,
smtp_username: '',
@@ -1177,9 +1364,22 @@ const form = reactive({
ops_monitoring_enabled: true,
ops_realtime_monitoring_enabled: true,
ops_query_mode_default: 'auto',
- ops_metrics_interval_seconds: 60
+ ops_metrics_interval_seconds: 60,
+ // Claude Code version check
+ min_claude_code_version: ''
})
+const defaultSubscriptionGroupOptions = computed(() =>
+ subscriptionGroups.value.map((group) => ({
+ value: group.id,
+ label: group.name,
+ description: group.description,
+ platform: group.platform,
+ subscriptionType: group.subscription_type,
+ rate: group.rate_multiplier
+ }))
+)
+
// LinuxDo OAuth redirect URL suggestion
const linuxdoRedirectUrlSuggestion = computed(() => {
if (typeof window === 'undefined') return ''
@@ -1239,6 +1439,14 @@ async function loadSettings() {
try {
const settings = await adminAPI.settings.getSettings()
Object.assign(form, settings)
+ form.default_subscriptions = Array.isArray(settings.default_subscriptions)
+ ? settings.default_subscriptions
+ .filter((item) => item.group_id > 0 && item.validity_days > 0)
+ .map((item) => ({
+ group_id: item.group_id,
+ validity_days: item.validity_days
+ }))
+ : []
form.smtp_password = ''
form.turnstile_secret_key = ''
form.linuxdo_connect_client_secret = ''
@@ -1251,9 +1459,60 @@ async function loadSettings() {
}
}
+async function loadSubscriptionGroups() {
+ try {
+ const groups = await adminAPI.groups.getAll()
+ subscriptionGroups.value = groups.filter(
+ (group) => group.subscription_type === 'subscription' && group.status === 'active'
+ )
+ } catch (error) {
+ console.error('Failed to load subscription groups:', error)
+ subscriptionGroups.value = []
+ }
+}
+
+function addDefaultSubscription() {
+ if (subscriptionGroups.value.length === 0) return
+ const existing = new Set(form.default_subscriptions.map((item) => item.group_id))
+ const candidate = subscriptionGroups.value.find((group) => !existing.has(group.id))
+ if (!candidate) return
+ form.default_subscriptions.push({
+ group_id: candidate.id,
+ validity_days: 30
+ })
+}
+
+function removeDefaultSubscription(index: number) {
+ form.default_subscriptions.splice(index, 1)
+}
+
async function saveSettings() {
saving.value = true
try {
+ const normalizedDefaultSubscriptions = form.default_subscriptions
+ .filter((item) => item.group_id > 0 && item.validity_days > 0)
+ .map((item: DefaultSubscriptionSetting) => ({
+ group_id: item.group_id,
+ validity_days: Math.min(36500, Math.max(1, Math.floor(item.validity_days)))
+ }))
+
+ const seenGroupIDs = new Set()
+ const duplicateDefaultSubscription = normalizedDefaultSubscriptions.find((item) => {
+ if (seenGroupIDs.has(item.group_id)) {
+ return true
+ }
+ seenGroupIDs.add(item.group_id)
+ return false
+ })
+ if (duplicateDefaultSubscription) {
+ appStore.showError(
+ t('admin.settings.defaults.defaultSubscriptionsDuplicate', {
+ groupId: duplicateDefaultSubscription.group_id
+ })
+ )
+ return
+ }
+
const payload: UpdateSettingsRequest = {
registration_enabled: form.registration_enabled,
email_verify_enabled: form.email_verify_enabled,
@@ -1263,6 +1522,7 @@ async function saveSettings() {
totp_enabled: form.totp_enabled,
default_balance: form.default_balance,
default_concurrency: form.default_concurrency,
+ default_subscriptions: normalizedDefaultSubscriptions,
site_name: form.site_name,
site_logo: form.site_logo,
site_subtitle: form.site_subtitle,
@@ -1273,6 +1533,7 @@ async function saveSettings() {
hide_ccs_import_button: form.hide_ccs_import_button,
purchase_subscription_enabled: form.purchase_subscription_enabled,
purchase_subscription_url: form.purchase_subscription_url,
+ sora_client_enabled: form.sora_client_enabled,
smtp_host: form.smtp_host,
smtp_port: form.smtp_port,
smtp_username: form.smtp_username,
@@ -1293,7 +1554,8 @@ async function saveSettings() {
fallback_model_gemini: form.fallback_model_gemini,
fallback_model_antigravity: form.fallback_model_antigravity,
enable_identity_patch: form.enable_identity_patch,
- identity_patch_prompt: form.identity_patch_prompt
+ identity_patch_prompt: form.identity_patch_prompt,
+ min_claude_code_version: form.min_claude_code_version
}
const updated = await adminAPI.settings.updateSettings(payload)
Object.assign(form, updated)
@@ -1459,7 +1721,18 @@ async function saveStreamTimeoutSettings() {
onMounted(() => {
loadSettings()
+ loadSubscriptionGroups()
loadAdminApiKey()
loadStreamTimeoutSettings()
})
+
+
diff --git a/frontend/src/views/admin/UsageView.vue b/frontend/src/views/admin/UsageView.vue
index dbc81f3a..b5aa63c8 100644
--- a/frontend/src/views/admin/UsageView.vue
+++ b/frontend/src/views/admin/UsageView.vue
@@ -14,11 +14,47 @@
-
+
+
-
-
+
+
+
+
+
+
+
+
+
+
+
@@ -38,17 +74,19 @@ import { useI18n } from 'vue-i18n'
import { saveAs } from 'file-saver'
import { useAppStore } from '@/stores/app'; import { adminAPI } from '@/api/admin'; import { adminUsageAPI } from '@/api/admin/usage'
import { formatReasoningEffort } from '@/utils/format'
+import { resolveUsageRequestType, requestTypeToLegacyStream } from '@/utils/usageRequestType'
import AppLayout from '@/components/layout/AppLayout.vue'; import Pagination from '@/components/common/Pagination.vue'; import Select from '@/components/common/Select.vue'
import UsageStatsCards from '@/components/admin/usage/UsageStatsCards.vue'; import UsageFilters from '@/components/admin/usage/UsageFilters.vue'
import UsageTable from '@/components/admin/usage/UsageTable.vue'; import UsageExportProgress from '@/components/admin/usage/UsageExportProgress.vue'
import UsageCleanupDialog from '@/components/admin/usage/UsageCleanupDialog.vue'
-import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue'; import TokenUsageTrend from '@/components/charts/TokenUsageTrend.vue'
-import type { AdminUsageLog, TrendDataPoint, ModelStat } from '@/types'; import type { AdminUsageStatsResponse, AdminUsageQueryParams } from '@/api/admin/usage'
+import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue'; import GroupDistributionChart from '@/components/charts/GroupDistributionChart.vue'; import TokenUsageTrend from '@/components/charts/TokenUsageTrend.vue'
+import Icon from '@/components/icons/Icon.vue'
+import type { AdminUsageLog, TrendDataPoint, ModelStat, GroupStat } from '@/types'; import type { AdminUsageStatsResponse, AdminUsageQueryParams } from '@/api/admin/usage'
const { t } = useI18n()
const appStore = useAppStore()
const usageStats = ref(null); const usageLogs = ref([]); const loading = ref(false); const exporting = ref(false)
-const trendData = ref([]); const modelStats = ref([]); const chartsLoading = ref(false); const granularity = ref<'day' | 'hour'>('day')
+const trendData = ref([]); const modelStats = ref([]); const groupStats = ref([]); const chartsLoading = ref(false); const granularity = ref<'day' | 'hour'>('day')
let abortController: AbortController | null = null; let exportAbortController: AbortController | null = null
const exportProgress = reactive({ show: false, progress: 0, current: 0, total: 0, estimatedTime: '' })
const cleanupDialogVisible = ref(false)
@@ -63,32 +101,53 @@ const formatLD = (d: Date) => {
}
const now = new Date(); const weekAgo = new Date(); weekAgo.setDate(weekAgo.getDate() - 6)
const startDate = ref(formatLD(weekAgo)); const endDate = ref(formatLD(now))
-const filters = ref({ user_id: undefined, model: undefined, group_id: undefined, billing_type: null, start_date: startDate.value, end_date: endDate.value })
+const filters = ref({ user_id: undefined, model: undefined, group_id: undefined, request_type: undefined, billing_type: null, start_date: startDate.value, end_date: endDate.value })
const pagination = reactive({ page: 1, page_size: 20, total: 0 })
const loadLogs = async () => {
abortController?.abort(); const c = new AbortController(); abortController = c; loading.value = true
try {
- const res = await adminAPI.usage.list({ page: pagination.page, page_size: pagination.page_size, ...filters.value }, { signal: c.signal })
+ const requestType = filters.value.request_type
+ const legacyStream = requestType ? requestTypeToLegacyStream(requestType) : filters.value.stream
+ const res = await adminAPI.usage.list({ page: pagination.page, page_size: pagination.page_size, ...filters.value, stream: legacyStream === null ? undefined : legacyStream }, { signal: c.signal })
if(!c.signal.aborted) { usageLogs.value = res.items; pagination.total = res.total }
} catch (error: any) { if(error?.name !== 'AbortError') console.error('Failed to load usage logs:', error) } finally { if(abortController === c) loading.value = false }
}
-const loadStats = async () => { try { const s = await adminAPI.usage.getStats(filters.value); usageStats.value = s } catch (error) { console.error('Failed to load usage stats:', error) } }
+const loadStats = async () => {
+ try {
+ const requestType = filters.value.request_type
+ const legacyStream = requestType ? requestTypeToLegacyStream(requestType) : filters.value.stream
+ const s = await adminAPI.usage.getStats({ ...filters.value, stream: legacyStream === null ? undefined : legacyStream })
+ usageStats.value = s
+ } catch (error) {
+ console.error('Failed to load usage stats:', error)
+ }
+}
const loadChartData = async () => {
chartsLoading.value = true
try {
- const params = { start_date: filters.value.start_date || startDate.value, end_date: filters.value.end_date || endDate.value, granularity: granularity.value, user_id: filters.value.user_id, model: filters.value.model, api_key_id: filters.value.api_key_id, account_id: filters.value.account_id, group_id: filters.value.group_id, stream: filters.value.stream, billing_type: filters.value.billing_type }
- const [trendRes, modelRes] = await Promise.all([adminAPI.dashboard.getUsageTrend(params), adminAPI.dashboard.getModelStats({ start_date: params.start_date, end_date: params.end_date, user_id: params.user_id, model: params.model, api_key_id: params.api_key_id, account_id: params.account_id, group_id: params.group_id, stream: params.stream, billing_type: params.billing_type })])
- trendData.value = trendRes.trend || []; modelStats.value = modelRes.models || []
+ const requestType = filters.value.request_type
+ const legacyStream = requestType ? requestTypeToLegacyStream(requestType) : filters.value.stream
+ const params = { start_date: filters.value.start_date || startDate.value, end_date: filters.value.end_date || endDate.value, granularity: granularity.value, user_id: filters.value.user_id, model: filters.value.model, api_key_id: filters.value.api_key_id, account_id: filters.value.account_id, group_id: filters.value.group_id, request_type: requestType, stream: legacyStream === null ? undefined : legacyStream, billing_type: filters.value.billing_type }
+ const statsParams = { start_date: params.start_date, end_date: params.end_date, user_id: params.user_id, model: params.model, api_key_id: params.api_key_id, account_id: params.account_id, group_id: params.group_id, request_type: params.request_type, stream: params.stream, billing_type: params.billing_type }
+ const [trendRes, modelRes, groupRes] = await Promise.all([adminAPI.dashboard.getUsageTrend(params), adminAPI.dashboard.getModelStats(statsParams), adminAPI.dashboard.getGroupStats(statsParams)])
+ trendData.value = trendRes.trend || []; modelStats.value = modelRes.models || []; groupStats.value = groupRes.groups || []
} catch (error) { console.error('Failed to load chart data:', error) } finally { chartsLoading.value = false }
}
const applyFilters = () => { pagination.page = 1; loadLogs(); loadStats(); loadChartData() }
const refreshData = () => { loadLogs(); loadStats(); loadChartData() }
-const resetFilters = () => { startDate.value = formatLD(weekAgo); endDate.value = formatLD(now); filters.value = { start_date: startDate.value, end_date: endDate.value, billing_type: null }; granularity.value = 'day'; applyFilters() }
+const resetFilters = () => { startDate.value = formatLD(weekAgo); endDate.value = formatLD(now); filters.value = { start_date: startDate.value, end_date: endDate.value, request_type: undefined, billing_type: null }; granularity.value = 'day'; applyFilters() }
const handlePageChange = (p: number) => { pagination.page = p; loadLogs() }
const handlePageSizeChange = (s: number) => { pagination.page_size = s; pagination.page = 1; loadLogs() }
const cancelExport = () => exportAbortController?.abort()
const openCleanupDialog = () => { cleanupDialogVisible.value = true }
+const getRequestTypeLabel = (log: AdminUsageLog): string => {
+ const requestType = resolveUsageRequestType(log)
+ if (requestType === 'ws_v2') return t('usage.ws')
+ if (requestType === 'stream') return t('usage.stream')
+ if (requestType === 'sync') return t('usage.sync')
+ return t('usage.unknown')
+}
const exportToExcel = async () => {
if (exporting.value) return; exporting.value = true; exportProgress.show = true
@@ -110,11 +169,13 @@ const exportToExcel = async () => {
]
const ws = XLSX.utils.aoa_to_sheet([headers])
while (true) {
- const res = await adminUsageAPI.list({ page: p, page_size: 100, ...filters.value }, { signal: c.signal })
+ const requestType = filters.value.request_type
+ const legacyStream = requestType ? requestTypeToLegacyStream(requestType) : filters.value.stream
+ const res = await adminUsageAPI.list({ page: p, page_size: 100, ...filters.value, stream: legacyStream === null ? undefined : legacyStream }, { signal: c.signal })
if (c.signal.aborted) break; if (p === 1) { total = res.total; exportProgress.total = total }
const rows = (res.items || []).map((log: AdminUsageLog) => [
log.created_at, log.user?.email || '', log.api_key?.name || '', log.account?.name || '', log.model,
- formatReasoningEffort(log.reasoning_effort), log.group?.name || '', log.stream ? t('usage.stream') : t('usage.sync'),
+ formatReasoningEffort(log.reasoning_effort), log.group?.name || '', getRequestTypeLabel(log),
log.input_tokens, log.output_tokens, log.cache_read_tokens, log.cache_creation_tokens,
log.input_cost?.toFixed(6) || '0.000000', log.output_cost?.toFixed(6) || '0.000000',
log.cache_read_cost?.toFixed(6) || '0.000000', log.cache_creation_cost?.toFixed(6) || '0.000000',
@@ -141,6 +202,77 @@ const exportToExcel = async () => {
finally { if(exportAbortController === c) { exportAbortController = null; exporting.value = false; exportProgress.show = false } }
}
-onMounted(() => { loadLogs(); loadStats(); loadChartData() })
-onUnmounted(() => { abortController?.abort(); exportAbortController?.abort() })
+// Column visibility
+const ALWAYS_VISIBLE = ['user', 'created_at']
+const DEFAULT_HIDDEN_COLUMNS = ['reasoning_effort', 'user_agent']
+const HIDDEN_COLUMNS_KEY = 'usage-hidden-columns'
+
+const allColumns = computed(() => [
+ { key: 'user', label: t('admin.usage.user'), sortable: false },
+ { key: 'api_key', label: t('usage.apiKeyFilter'), sortable: false },
+ { key: 'account', label: t('admin.usage.account'), sortable: false },
+ { key: 'model', label: t('usage.model'), sortable: true },
+ { key: 'reasoning_effort', label: t('usage.reasoningEffort'), sortable: false },
+ { key: 'group', label: t('admin.usage.group'), sortable: false },
+ { key: 'stream', label: t('usage.type'), sortable: false },
+ { key: 'tokens', label: t('usage.tokens'), sortable: false },
+ { key: 'cost', label: t('usage.cost'), sortable: false },
+ { key: 'first_token', label: t('usage.firstToken'), sortable: false },
+ { key: 'duration', label: t('usage.duration'), sortable: false },
+ { key: 'created_at', label: t('usage.time'), sortable: true },
+ { key: 'user_agent', label: t('usage.userAgent'), sortable: false },
+ { key: 'ip_address', label: t('admin.usage.ipAddress'), sortable: false }
+])
+
+const hiddenColumns = reactive>(new Set())
+
+const toggleableColumns = computed(() =>
+ allColumns.value.filter(col => !ALWAYS_VISIBLE.includes(col.key))
+)
+
+const visibleColumns = computed(() =>
+ allColumns.value.filter(col =>
+ ALWAYS_VISIBLE.includes(col.key) || !hiddenColumns.has(col.key)
+ )
+)
+
+const isColumnVisible = (key: string) => !hiddenColumns.has(key)
+
+const toggleColumn = (key: string) => {
+ if (hiddenColumns.has(key)) {
+ hiddenColumns.delete(key)
+ } else {
+ hiddenColumns.add(key)
+ }
+ try {
+ localStorage.setItem(HIDDEN_COLUMNS_KEY, JSON.stringify([...hiddenColumns]))
+ } catch (e) {
+ console.error('Failed to save columns:', e)
+ }
+}
+
+const loadSavedColumns = () => {
+ try {
+ const saved = localStorage.getItem(HIDDEN_COLUMNS_KEY)
+ if (saved) {
+ (JSON.parse(saved) as string[]).forEach(key => hiddenColumns.add(key))
+ } else {
+ DEFAULT_HIDDEN_COLUMNS.forEach(key => hiddenColumns.add(key))
+ }
+ } catch {
+ DEFAULT_HIDDEN_COLUMNS.forEach(key => hiddenColumns.add(key))
+ }
+}
+
+const showColumnDropdown = ref(false)
+const columnDropdownRef = ref(null)
+
+const handleColumnClickOutside = (event: MouseEvent) => {
+ if (columnDropdownRef.value && !columnDropdownRef.value.contains(event.target as HTMLElement)) {
+ showColumnDropdown.value = false
+ }
+}
+
+onMounted(() => { loadLogs(); loadStats(); loadChartData(); loadSavedColumns(); document.addEventListener('click', handleColumnClickOutside) })
+onUnmounted(() => { abortController?.abort(); exportAbortController?.abort(); document.removeEventListener('click', handleColumnClickOutside) })
diff --git a/frontend/src/views/user/PurchaseSubscriptionView.vue b/frontend/src/views/user/PurchaseSubscriptionView.vue
index 55bcf307..fdcd0d34 100644
--- a/frontend/src/views/user/PurchaseSubscriptionView.vue
+++ b/frontend/src/views/user/PurchaseSubscriptionView.vue
@@ -1,30 +1,6 @@
-
-
-
- {{ t('purchase.title') }}
-
-
- {{ t('purchase.description') }}
-
-
-
-
-
-
-
diff --git a/frontend/src/views/user/SoraView.vue b/frontend/src/views/user/SoraView.vue
new file mode 100644
index 00000000..0ebea5b0
--- /dev/null
+++ b/frontend/src/views/user/SoraView.vue
@@ -0,0 +1,369 @@
+
+
+
+
+
+
+
+
{{ t('sora.notEnabled') }}
+
{{ t('sora.notEnabledDesc') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/views/user/UsageView.vue b/frontend/src/views/user/UsageView.vue
index 53a11702..ff875325 100644
--- a/frontend/src/views/user/UsageView.vue
+++ b/frontend/src/views/user/UsageView.vue
@@ -166,13 +166,9 @@
- {{ row.stream ? t('usage.stream') : t('usage.sync') }}
+ {{ getRequestTypeLabel(row) }}
@@ -473,12 +469,13 @@ import TablePageLayout from '@/components/layout/TablePageLayout.vue'
import DataTable from '@/components/common/DataTable.vue'
import Pagination from '@/components/common/Pagination.vue'
import EmptyState from '@/components/common/EmptyState.vue'
- import Select from '@/components/common/Select.vue'
- import DateRangePicker from '@/components/common/DateRangePicker.vue'
- import Icon from '@/components/icons/Icon.vue'
- import type { UsageLog, ApiKey, UsageQueryParams, UsageStatsResponse } from '@/types'
- import type { Column } from '@/components/common/types'
- import { formatDateTime, formatReasoningEffort } from '@/utils/format'
+import Select from '@/components/common/Select.vue'
+import DateRangePicker from '@/components/common/DateRangePicker.vue'
+import Icon from '@/components/icons/Icon.vue'
+import type { UsageLog, ApiKey, UsageQueryParams, UsageStatsResponse } from '@/types'
+import type { Column } from '@/components/common/types'
+import { formatDateTime, formatReasoningEffort } from '@/utils/format'
+import { resolveUsageRequestType } from '@/utils/usageRequestType'
const { t } = useI18n()
const appStore = useAppStore()
@@ -577,6 +574,30 @@ const formatUserAgent = (ua: string): string => {
return ua
}
+const getRequestTypeLabel = (log: UsageLog): string => {
+ const requestType = resolveUsageRequestType(log)
+ if (requestType === 'ws_v2') return t('usage.ws')
+ if (requestType === 'stream') return t('usage.stream')
+ if (requestType === 'sync') return t('usage.sync')
+ return t('usage.unknown')
+}
+
+const getRequestTypeBadgeClass = (log: UsageLog): string => {
+ const requestType = resolveUsageRequestType(log)
+ if (requestType === 'ws_v2') return 'bg-violet-100 text-violet-800 dark:bg-violet-900 dark:text-violet-200'
+ if (requestType === 'stream') return 'bg-blue-100 text-blue-800 dark:bg-blue-900 dark:text-blue-200'
+ if (requestType === 'sync') return 'bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-gray-200'
+ return 'bg-amber-100 text-amber-800 dark:bg-amber-900 dark:text-amber-200'
+}
+
+const getRequestTypeExportText = (log: UsageLog): string => {
+ const requestType = resolveUsageRequestType(log)
+ if (requestType === 'ws_v2') return 'WS'
+ if (requestType === 'stream') return 'Stream'
+ if (requestType === 'sync') return 'Sync'
+ return 'Unknown'
+}
+
const formatTokens = (value: number): string => {
if (value >= 1_000_000_000) {
return `${(value / 1_000_000_000).toFixed(2)}B`
@@ -768,7 +789,7 @@ const exportToCSV = async () => {
log.api_key?.name || '',
log.model,
formatReasoningEffort(log.reasoning_effort),
- log.stream ? 'Stream' : 'Sync',
+ getRequestTypeExportText(log),
log.input_tokens,
log.output_tokens,
log.cache_read_tokens,
diff --git a/frontend/vitest.config.ts b/frontend/vitest.config.ts
index 2ff23c77..39568250 100644
--- a/frontend/vitest.config.ts
+++ b/frontend/vitest.config.ts
@@ -1,7 +1,9 @@
import { defineConfig } from 'vitest/config'
import { resolve } from 'path'
+import vue from '@vitejs/plugin-vue'
export default defineConfig({
+ plugins: [vue()],
resolve: {
alias: {
'@': resolve(__dirname, 'src'),
diff --git a/openspec/config.yaml b/openspec/config.yaml
new file mode 100644
index 00000000..392946c6
--- /dev/null
+++ b/openspec/config.yaml
@@ -0,0 +1,20 @@
+schema: spec-driven
+
+# Project context (optional)
+# This is shown to AI when creating artifacts.
+# Add your tech stack, conventions, style guides, domain knowledge, etc.
+# Example:
+# context: |
+# Tech stack: TypeScript, React, Node.js
+# We use conventional commits
+# Domain: e-commerce platform
+
+# Per-artifact rules (optional)
+# Add custom rules for specific artifacts.
+# Example:
+# rules:
+# proposal:
+# - Keep proposals under 500 words
+# - Always include a "Non-goals" section
+# tasks:
+# - Break tasks into chunks of max 2 hours
diff --git a/openspec/project.md b/openspec/project.md
new file mode 100644
index 00000000..3da5119d
--- /dev/null
+++ b/openspec/project.md
@@ -0,0 +1,31 @@
+# Project Context
+
+## Purpose
+[Describe your project's purpose and goals]
+
+## Tech Stack
+- [List your primary technologies]
+- [e.g., TypeScript, React, Node.js]
+
+## Project Conventions
+
+### Code Style
+[Describe your code style preferences, formatting rules, and naming conventions]
+
+### Architecture Patterns
+[Document your architectural decisions and patterns]
+
+### Testing Strategy
+[Explain your testing approach and requirements]
+
+### Git Workflow
+[Describe your branching strategy and commit conventions]
+
+## Domain Context
+[Add domain-specific knowledge that AI assistants need to understand]
+
+## Important Constraints
+[List any technical, business, or regulatory constraints]
+
+## External Dependencies
+[Document key external services, APIs, or systems]
diff --git a/tools/perf/openai_responses_ws_v2_compare_k6.js b/tools/perf/openai_responses_ws_v2_compare_k6.js
new file mode 100644
index 00000000..6bb4b9a2
--- /dev/null
+++ b/tools/perf/openai_responses_ws_v2_compare_k6.js
@@ -0,0 +1,167 @@
+import http from 'k6/http';
+import { check, sleep } from 'k6';
+import { Rate, Trend } from 'k6/metrics';
+
+const baseURL = (__ENV.BASE_URL || 'http://127.0.0.1:5231').replace(/\/$/, '');
+const httpAPIKey = (__ENV.HTTP_API_KEY || '').trim();
+const wsAPIKey = (__ENV.WS_API_KEY || '').trim();
+const model = __ENV.MODEL || 'gpt-5.1';
+const duration = __ENV.DURATION || '5m';
+const timeout = __ENV.TIMEOUT || '180s';
+
+const httpRPS = Number(__ENV.HTTP_RPS || 10);
+const wsRPS = Number(__ENV.WS_RPS || 10);
+const chainRPS = Number(__ENV.CHAIN_RPS || 1);
+const chainRounds = Number(__ENV.CHAIN_ROUNDS || 20);
+const preAllocatedVUs = Number(__ENV.PRE_ALLOCATED_VUS || 40);
+const maxVUs = Number(__ENV.MAX_VUS || 300);
+
+const httpDurationMs = new Trend('openai_http_req_duration_ms', true);
+const wsDurationMs = new Trend('openai_ws_req_duration_ms', true);
+const wsChainDurationMs = new Trend('openai_ws_chain_round_duration_ms', true);
+const wsChainTTFTMs = new Trend('openai_ws_chain_round_ttft_ms', true);
+const httpNon2xxRate = new Rate('openai_http_non2xx_rate');
+const wsNon2xxRate = new Rate('openai_ws_non2xx_rate');
+const wsChainRoundSuccessRate = new Rate('openai_ws_chain_round_success_rate');
+
+export const options = {
+ scenarios: {
+ http_baseline: {
+ executor: 'constant-arrival-rate',
+ exec: 'runHTTPBaseline',
+ rate: httpRPS,
+ timeUnit: '1s',
+ duration,
+ preAllocatedVUs,
+ maxVUs,
+ tags: { path: 'http_baseline' },
+ },
+ ws_baseline: {
+ executor: 'constant-arrival-rate',
+ exec: 'runWSBaseline',
+ rate: wsRPS,
+ timeUnit: '1s',
+ duration,
+ preAllocatedVUs,
+ maxVUs,
+ tags: { path: 'ws_baseline' },
+ },
+ ws_chain_20_rounds: {
+ executor: 'constant-arrival-rate',
+ exec: 'runWSChain20Rounds',
+ rate: chainRPS,
+ timeUnit: '1s',
+ duration,
+ preAllocatedVUs: Math.max(2, Math.ceil(chainRPS * 2)),
+ maxVUs: Math.max(20, Math.ceil(chainRPS * 10)),
+ tags: { path: 'ws_chain_20_rounds' },
+ },
+ },
+ thresholds: {
+ openai_http_non2xx_rate: ['rate<0.02'],
+ openai_ws_non2xx_rate: ['rate<0.02'],
+ openai_http_req_duration_ms: ['p(95)<4000', 'p(99)<7000'],
+ openai_ws_req_duration_ms: ['p(95)<3000', 'p(99)<6000'],
+ openai_ws_chain_round_success_rate: ['rate>0.98'],
+ openai_ws_chain_round_ttft_ms: ['p(99)<1200'],
+ },
+};
+
+function buildHeaders(apiKey) {
+ const headers = {
+ 'Content-Type': 'application/json',
+ 'User-Agent': 'codex_cli_rs/0.98.0',
+ };
+ if (apiKey) {
+ headers.Authorization = `Bearer ${apiKey}`;
+ }
+ return headers;
+}
+
+function buildBody(previousResponseID) {
+ const body = {
+ model,
+ stream: false,
+ input: [
+ {
+ role: 'user',
+ content: [{ type: 'input_text', text: '请回复一个单词: pong' }],
+ },
+ ],
+ max_output_tokens: 64,
+ };
+ if (previousResponseID) {
+ body.previous_response_id = previousResponseID;
+ }
+ return JSON.stringify(body);
+}
+
+function postResponses(apiKey, body, tags) {
+ const res = http.post(`${baseURL}/v1/responses`, body, {
+ headers: buildHeaders(apiKey),
+ timeout,
+ tags,
+ });
+ check(res, {
+ 'status is 2xx': (r) => r.status >= 200 && r.status < 300,
+ });
+ return res;
+}
+
+function parseResponseID(res) {
+ if (!res || !res.body) {
+ return '';
+ }
+ try {
+ const payload = JSON.parse(res.body);
+ if (payload && typeof payload.id === 'string') {
+ return payload.id.trim();
+ }
+ } catch (_) {
+ return '';
+ }
+ return '';
+}
+
+export function runHTTPBaseline() {
+ const res = postResponses(httpAPIKey, buildBody(''), { transport: 'http' });
+ httpDurationMs.add(res.timings.duration, { transport: 'http' });
+ httpNon2xxRate.add(res.status < 200 || res.status >= 300, { transport: 'http' });
+}
+
+export function runWSBaseline() {
+ const res = postResponses(wsAPIKey, buildBody(''), { transport: 'ws_v2' });
+ wsDurationMs.add(res.timings.duration, { transport: 'ws_v2' });
+ wsNon2xxRate.add(res.status < 200 || res.status >= 300, { transport: 'ws_v2' });
+}
+
+// 20+ 轮续链专项,验证 previous_response_id 在长链下的稳定性与时延。
+export function runWSChain20Rounds() {
+ let previousResponseID = '';
+ for (let round = 1; round <= chainRounds; round += 1) {
+ const roundStart = Date.now();
+ const res = postResponses(wsAPIKey, buildBody(previousResponseID), { transport: 'ws_v2_chain' });
+ const ok = res.status >= 200 && res.status < 300;
+ wsChainRoundSuccessRate.add(ok, { round: `${round}` });
+ wsChainDurationMs.add(Date.now() - roundStart, { round: `${round}` });
+ wsChainTTFTMs.add(res.timings.waiting, { round: `${round}` });
+ wsNon2xxRate.add(!ok, { transport: 'ws_v2_chain' });
+ if (!ok) {
+ return;
+ }
+ const respID = parseResponseID(res);
+ if (!respID) {
+ wsChainRoundSuccessRate.add(false, { round: `${round}`, reason: 'missing_response_id' });
+ return;
+ }
+ previousResponseID = respID;
+ sleep(0.01);
+ }
+}
+
+export function handleSummary(data) {
+ return {
+ stdout: `\nOpenAI WSv2 对比压测完成\n${JSON.stringify(data.metrics, null, 2)}\n`,
+ 'docs/perf/openai-ws-v2-compare-summary.json': JSON.stringify(data, null, 2),
+ };
+}
diff --git a/tools/perf/openai_ws_pooling_compare_k6.js b/tools/perf/openai_ws_pooling_compare_k6.js
new file mode 100644
index 00000000..d8210479
--- /dev/null
+++ b/tools/perf/openai_ws_pooling_compare_k6.js
@@ -0,0 +1,123 @@
+import http from 'k6/http';
+import { check } from 'k6';
+import { Rate, Trend } from 'k6/metrics';
+
+const pooledBaseURL = (__ENV.POOLED_BASE_URL || 'http://127.0.0.1:5231').replace(/\/$/, '');
+const oneToOneBaseURL = (__ENV.ONE_TO_ONE_BASE_URL || '').replace(/\/$/, '');
+const wsAPIKey = (__ENV.WS_API_KEY || '').trim();
+const model = __ENV.MODEL || 'gpt-5.1';
+const timeout = __ENV.TIMEOUT || '180s';
+const duration = __ENV.DURATION || '5m';
+const pooledRPS = Number(__ENV.POOLED_RPS || 12);
+const oneToOneRPS = Number(__ENV.ONE_TO_ONE_RPS || 12);
+const preAllocatedVUs = Number(__ENV.PRE_ALLOCATED_VUS || 50);
+const maxVUs = Number(__ENV.MAX_VUS || 400);
+
+const pooledDurationMs = new Trend('openai_ws_pooled_duration_ms', true);
+const oneToOneDurationMs = new Trend('openai_ws_one_to_one_duration_ms', true);
+const pooledTTFTMs = new Trend('openai_ws_pooled_ttft_ms', true);
+const oneToOneTTFTMs = new Trend('openai_ws_one_to_one_ttft_ms', true);
+const pooledNon2xxRate = new Rate('openai_ws_pooled_non2xx_rate');
+const oneToOneNon2xxRate = new Rate('openai_ws_one_to_one_non2xx_rate');
+
+export const options = {
+ scenarios: {
+ pooled_mode: {
+ executor: 'constant-arrival-rate',
+ exec: 'runPooledMode',
+ rate: pooledRPS,
+ timeUnit: '1s',
+ duration,
+ preAllocatedVUs,
+ maxVUs,
+ tags: { mode: 'pooled' },
+ },
+ one_to_one_mode: {
+ executor: 'constant-arrival-rate',
+ exec: 'runOneToOneMode',
+ rate: oneToOneRPS,
+ timeUnit: '1s',
+ duration,
+ preAllocatedVUs,
+ maxVUs,
+ tags: { mode: 'one_to_one' },
+ startTime: '5s',
+ },
+ },
+ thresholds: {
+ openai_ws_pooled_non2xx_rate: ['rate<0.02'],
+ openai_ws_one_to_one_non2xx_rate: ['rate<0.02'],
+ openai_ws_pooled_duration_ms: ['p(95)<3000', 'p(99)<6000'],
+ openai_ws_one_to_one_duration_ms: ['p(95)<6000', 'p(99)<10000'],
+ },
+};
+
+function buildHeaders() {
+ const headers = {
+ 'Content-Type': 'application/json',
+ 'User-Agent': 'codex_cli_rs/0.98.0',
+ };
+ if (wsAPIKey) {
+ headers.Authorization = `Bearer ${wsAPIKey}`;
+ }
+ return headers;
+}
+
+function buildBody() {
+ return JSON.stringify({
+ model,
+ stream: false,
+ input: [
+ {
+ role: 'user',
+ content: [{ type: 'input_text', text: '请回复: pong' }],
+ },
+ ],
+ max_output_tokens: 48,
+ });
+}
+
+function send(baseURL, mode) {
+ if (!baseURL) {
+ return null;
+ }
+ const res = http.post(`${baseURL}/v1/responses`, buildBody(), {
+ headers: buildHeaders(),
+ timeout,
+ tags: { mode },
+ });
+ check(res, {
+ 'status is 2xx': (r) => r.status >= 200 && r.status < 300,
+ });
+ return res;
+}
+
+export function runPooledMode() {
+ const res = send(pooledBaseURL, 'pooled');
+ if (!res) {
+ return;
+ }
+ pooledDurationMs.add(res.timings.duration, { mode: 'pooled' });
+ pooledTTFTMs.add(res.timings.waiting, { mode: 'pooled' });
+ pooledNon2xxRate.add(res.status < 200 || res.status >= 300, { mode: 'pooled' });
+}
+
+export function runOneToOneMode() {
+ if (!oneToOneBaseURL) {
+ return;
+ }
+ const res = send(oneToOneBaseURL, 'one_to_one');
+ if (!res) {
+ return;
+ }
+ oneToOneDurationMs.add(res.timings.duration, { mode: 'one_to_one' });
+ oneToOneTTFTMs.add(res.timings.waiting, { mode: 'one_to_one' });
+ oneToOneNon2xxRate.add(res.status < 200 || res.status >= 300, { mode: 'one_to_one' });
+}
+
+export function handleSummary(data) {
+ return {
+ stdout: `\nOpenAI WS 池化 vs 1:1 对比压测完成\n${JSON.stringify(data.metrics, null, 2)}\n`,
+ 'docs/perf/openai-ws-pooling-compare-summary.json': JSON.stringify(data, null, 2),
+ };
+}
diff --git a/tools/perf/openai_ws_v2_perf_suite_k6.js b/tools/perf/openai_ws_v2_perf_suite_k6.js
new file mode 100644
index 00000000..df700270
--- /dev/null
+++ b/tools/perf/openai_ws_v2_perf_suite_k6.js
@@ -0,0 +1,216 @@
+import http from 'k6/http';
+import { check, sleep } from 'k6';
+import { Rate, Trend } from 'k6/metrics';
+
+const baseURL = (__ENV.BASE_URL || 'http://127.0.0.1:5231').replace(/\/$/, '');
+const wsAPIKey = (__ENV.WS_API_KEY || '').trim();
+const wsHotspotAPIKey = (__ENV.WS_HOTSPOT_API_KEY || wsAPIKey).trim();
+const model = __ENV.MODEL || 'gpt-5.3-codex';
+const duration = __ENV.DURATION || '5m';
+const timeout = __ENV.TIMEOUT || '180s';
+
+const shortRPS = Number(__ENV.SHORT_RPS || 12);
+const longRPS = Number(__ENV.LONG_RPS || 4);
+const errorRPS = Number(__ENV.ERROR_RPS || 2);
+const hotspotRPS = Number(__ENV.HOTSPOT_RPS || 10);
+const preAllocatedVUs = Number(__ENV.PRE_ALLOCATED_VUS || 50);
+const maxVUs = Number(__ENV.MAX_VUS || 400);
+
+const reqDurationMs = new Trend('openai_ws_v2_perf_req_duration_ms', true);
+const ttftMs = new Trend('openai_ws_v2_perf_ttft_ms', true);
+const non2xxRate = new Rate('openai_ws_v2_perf_non2xx_rate');
+const doneRate = new Rate('openai_ws_v2_perf_done_rate');
+const expectedErrorRate = new Rate('openai_ws_v2_perf_expected_error_rate');
+
+export const options = {
+ scenarios: {
+ short_request: {
+ executor: 'constant-arrival-rate',
+ exec: 'runShortRequest',
+ rate: shortRPS,
+ timeUnit: '1s',
+ duration,
+ preAllocatedVUs,
+ maxVUs,
+ tags: { scenario: 'short_request' },
+ },
+ long_request: {
+ executor: 'constant-arrival-rate',
+ exec: 'runLongRequest',
+ rate: longRPS,
+ timeUnit: '1s',
+ duration,
+ preAllocatedVUs: Math.max(20, Math.ceil(longRPS * 6)),
+ maxVUs: Math.max(100, Math.ceil(longRPS * 20)),
+ tags: { scenario: 'long_request' },
+ },
+ error_injection: {
+ executor: 'constant-arrival-rate',
+ exec: 'runErrorInjection',
+ rate: errorRPS,
+ timeUnit: '1s',
+ duration,
+ preAllocatedVUs: Math.max(8, Math.ceil(errorRPS * 4)),
+ maxVUs: Math.max(40, Math.ceil(errorRPS * 12)),
+ tags: { scenario: 'error_injection' },
+ },
+ hotspot_account: {
+ executor: 'constant-arrival-rate',
+ exec: 'runHotspotAccount',
+ rate: hotspotRPS,
+ timeUnit: '1s',
+ duration,
+ preAllocatedVUs: Math.max(16, Math.ceil(hotspotRPS * 3)),
+ maxVUs: Math.max(80, Math.ceil(hotspotRPS * 10)),
+ tags: { scenario: 'hotspot_account' },
+ },
+ },
+ thresholds: {
+ openai_ws_v2_perf_non2xx_rate: ['rate<0.05'],
+ openai_ws_v2_perf_req_duration_ms: ['p(95)<5000', 'p(99)<9000'],
+ openai_ws_v2_perf_ttft_ms: ['p(99)<2000'],
+ openai_ws_v2_perf_done_rate: ['rate>0.95'],
+ },
+};
+
+function buildHeaders(apiKey, opts = {}) {
+ const headers = {
+ 'Content-Type': 'application/json',
+ 'User-Agent': 'codex_cli_rs/0.104.0',
+ 'OpenAI-Beta': 'responses_websockets=2026-02-06,responses=experimental',
+ };
+ if (apiKey) {
+ headers.Authorization = `Bearer ${apiKey}`;
+ }
+ if (opts.sessionID) {
+ headers.session_id = opts.sessionID;
+ }
+ if (opts.conversationID) {
+ headers.conversation_id = opts.conversationID;
+ }
+ return headers;
+}
+
+function shortBody() {
+ return JSON.stringify({
+ model,
+ stream: false,
+ input: [
+ {
+ role: 'user',
+ content: [{ type: 'input_text', text: '请回复一个词:pong' }],
+ },
+ ],
+ max_output_tokens: 64,
+ });
+}
+
+function longBody() {
+ const tools = [];
+ for (let i = 0; i < 28; i += 1) {
+ tools.push({
+ type: 'function',
+ name: `perf_tool_${i}`,
+ description: 'load test tool schema',
+ parameters: {
+ type: 'object',
+ properties: {
+ query: { type: 'string' },
+ limit: { type: 'number' },
+ with_cache: { type: 'boolean' },
+ },
+ required: ['query'],
+ },
+ });
+ }
+
+ const input = [];
+ for (let i = 0; i < 20; i += 1) {
+ input.push({
+ role: 'user',
+ content: [{ type: 'input_text', text: `长请求压测消息 ${i}: 请输出简要摘要。` }],
+ });
+ }
+
+ return JSON.stringify({
+ model,
+ stream: false,
+ input,
+ tools,
+ parallel_tool_calls: true,
+ max_output_tokens: 256,
+ reasoning: { effort: 'medium' },
+ instructions: '你是压测助手,简洁回复。',
+ });
+}
+
+function errorInjectionBody() {
+ return JSON.stringify({
+ model,
+ stream: false,
+ previous_response_id: `resp_not_found_${__VU}_${__ITER}`,
+ input: [
+ {
+ role: 'user',
+ content: [{ type: 'input_text', text: '触发错误注入路径。' }],
+ },
+ ],
+ });
+}
+
+function postResponses(apiKey, body, tags, opts = {}) {
+ const res = http.post(`${baseURL}/v1/responses`, body, {
+ headers: buildHeaders(apiKey, opts),
+ timeout,
+ tags,
+ });
+ reqDurationMs.add(res.timings.duration, tags);
+ ttftMs.add(res.timings.waiting, tags);
+ non2xxRate.add(res.status < 200 || res.status >= 300, tags);
+ return res;
+}
+
+function hasDone(res) {
+ return !!res && !!res.body && res.body.indexOf('[DONE]') >= 0;
+}
+
+export function runShortRequest() {
+ const tags = { scenario: 'short_request' };
+ const res = postResponses(wsAPIKey, shortBody(), tags);
+ check(res, { 'short status is 2xx': (r) => r.status >= 200 && r.status < 300 });
+ doneRate.add(hasDone(res) || (res.status >= 200 && res.status < 300), tags);
+}
+
+export function runLongRequest() {
+ const tags = { scenario: 'long_request' };
+ const res = postResponses(wsAPIKey, longBody(), tags);
+ check(res, { 'long status is 2xx': (r) => r.status >= 200 && r.status < 300 });
+ doneRate.add(hasDone(res) || (res.status >= 200 && res.status < 300), tags);
+}
+
+export function runErrorInjection() {
+ const tags = { scenario: 'error_injection' };
+ const res = postResponses(wsAPIKey, errorInjectionBody(), tags);
+ // 错误注入场景允许 4xx/5xx,重点观测 fallback 和错误路径抖动。
+ expectedErrorRate.add(res.status >= 400, tags);
+ doneRate.add(hasDone(res), tags);
+}
+
+export function runHotspotAccount() {
+ const tags = { scenario: 'hotspot_account' };
+ const opts = {
+ sessionID: 'perf-hotspot-session-fixed',
+ conversationID: 'perf-hotspot-conversation-fixed',
+ };
+ const res = postResponses(wsHotspotAPIKey, shortBody(), tags, opts);
+ check(res, { 'hotspot status is 2xx': (r) => r.status >= 200 && r.status < 300 });
+ doneRate.add(hasDone(res) || (res.status >= 200 && res.status < 300), tags);
+ sleep(0.01);
+}
+
+export function handleSummary(data) {
+ return {
+ stdout: `\nOpenAI WSv2 性能套件压测完成\n${JSON.stringify(data.metrics, null, 2)}\n`,
+ 'docs/perf/openai-ws-v2-perf-suite-summary.json': JSON.stringify(data, null, 2),
+ };
+}
diff --git a/tools/sora-test b/tools/sora-test
new file mode 100755
index 00000000..cb6c2f83
--- /dev/null
+++ b/tools/sora-test
@@ -0,0 +1,192 @@
+#!/usr/bin/env python3
+"""
+Sora access token tester.
+
+Usage:
+ tools/sora-test -at "
"
+"""
+
+from __future__ import annotations
+
+import argparse
+import base64
+import json
+import sys
+import textwrap
+import urllib.error
+import urllib.request
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from typing import Dict, Optional, Tuple
+
+
+DEFAULT_BASE_URL = "https://sora.chatgpt.com"
+DEFAULT_TIMEOUT = 20
+DEFAULT_USER_AGENT = "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)"
+
+
+@dataclass
+class EndpointResult:
+ path: str
+ status: int
+ request_id: str
+ cf_ray: str
+ body_preview: str
+
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(
+ description="Test Sora access token against core backend endpoints.",
+ formatter_class=argparse.RawTextHelpFormatter,
+ epilog=textwrap.dedent(
+ """\
+ Examples:
+ tools/sora-test -at "eyJhbGciOi..."
+ tools/sora-test -at "eyJhbGciOi..." --timeout 30
+ """
+ ),
+ )
+ parser.add_argument("-at", "--access-token", required=True, help="Sora/OpenAI access token (JWT)")
+ parser.add_argument(
+ "--base-url",
+ default=DEFAULT_BASE_URL,
+ help=f"Base URL for Sora backend (default: {DEFAULT_BASE_URL})",
+ )
+ parser.add_argument(
+ "--timeout",
+ type=int,
+ default=DEFAULT_TIMEOUT,
+ help=f"HTTP timeout seconds (default: {DEFAULT_TIMEOUT})",
+ )
+ return parser.parse_args()
+
+
+def mask_token(token: str) -> str:
+ if len(token) <= 16:
+ return token
+ return f"{token[:10]}...{token[-6:]}"
+
+
+def decode_jwt_payload(token: str) -> Optional[Dict]:
+ parts = token.split(".")
+ if len(parts) != 3:
+ return None
+ payload = parts[1]
+ payload += "=" * ((4 - len(payload) % 4) % 4)
+ payload = payload.replace("-", "+").replace("_", "/")
+ try:
+ decoded = base64.b64decode(payload)
+ return json.loads(decoded.decode("utf-8", errors="replace"))
+ except Exception:
+ return None
+
+
+def ts_to_iso(ts: Optional[int]) -> str:
+ if not ts:
+ return "-"
+ try:
+ return datetime.fromtimestamp(ts, tz=timezone.utc).isoformat()
+ except Exception:
+ return "-"
+
+
+def http_get(base_url: str, path: str, access_token: str, timeout: int) -> EndpointResult:
+ url = base_url.rstrip("/") + path
+ req = urllib.request.Request(url=url, method="GET")
+ req.add_header("Authorization", f"Bearer {access_token}")
+ req.add_header("Accept", "application/json, text/plain, */*")
+ req.add_header("Origin", DEFAULT_BASE_URL)
+ req.add_header("Referer", DEFAULT_BASE_URL + "/")
+ req.add_header("User-Agent", DEFAULT_USER_AGENT)
+
+ try:
+ with urllib.request.urlopen(req, timeout=timeout) as resp:
+ raw = resp.read()
+ body = raw.decode("utf-8", errors="replace")
+ return EndpointResult(
+ path=path,
+ status=resp.getcode(),
+ request_id=(resp.headers.get("x-request-id") or "").strip(),
+ cf_ray=(resp.headers.get("cf-ray") or "").strip(),
+ body_preview=body[:500].replace("\n", " "),
+ )
+ except urllib.error.HTTPError as e:
+ raw = e.read()
+ body = raw.decode("utf-8", errors="replace")
+ return EndpointResult(
+ path=path,
+ status=e.code,
+ request_id=(e.headers.get("x-request-id") if e.headers else "") or "",
+ cf_ray=(e.headers.get("cf-ray") if e.headers else "") or "",
+ body_preview=body[:500].replace("\n", " "),
+ )
+ except Exception as e:
+ return EndpointResult(
+ path=path,
+ status=0,
+ request_id="",
+ cf_ray="",
+ body_preview=f"network_error: {e}",
+ )
+
+
+def classify(me_status: int) -> Tuple[str, int]:
+ if me_status == 200:
+ return "AT looks valid for Sora (/backend/me == 200).", 0
+ if me_status == 401:
+ return "AT is invalid or expired (/backend/me == 401).", 2
+ if me_status == 403:
+ return "AT may be blocked by policy/challenge or lacks permission (/backend/me == 403).", 3
+ if me_status == 0:
+ return "Request failed before reaching Sora (network/proxy/TLS issue).", 4
+ return f"Unexpected status on /backend/me: {me_status}", 5
+
+
+def main() -> int:
+ args = parse_args()
+ token = args.access_token.strip()
+ if not token:
+ print("ERROR: empty access token")
+ return 1
+
+ payload = decode_jwt_payload(token)
+ print("=== Sora AT Test ===")
+ print(f"token: {mask_token(token)}")
+ if payload:
+ exp = payload.get("exp")
+ iat = payload.get("iat")
+ scopes = payload.get("scp")
+ scope_count = len(scopes) if isinstance(scopes, list) else 0
+ print(f"jwt.iat: {iat} ({ts_to_iso(iat)})")
+ print(f"jwt.exp: {exp} ({ts_to_iso(exp)})")
+ print(f"jwt.scope_count: {scope_count}")
+ else:
+ print("jwt: payload decode failed (token may not be JWT)")
+
+ endpoints = [
+ "/backend/me",
+ "/backend/nf/check",
+ "/backend/project_y/invite/mine",
+ "/backend/billing/subscriptions",
+ ]
+
+ print("\n--- endpoint checks ---")
+ results = []
+ for path in endpoints:
+ res = http_get(args.base_url, path, token, args.timeout)
+ results.append(res)
+ print(f"{res.path} -> status={res.status} request_id={res.request_id or '-'} cf_ray={res.cf_ray or '-'}")
+ if res.body_preview:
+ print(f" body: {res.body_preview}")
+
+ me_result = next((r for r in results if r.path == "/backend/me"), None)
+ me_status = me_result.status if me_result else 0
+ summary, code = classify(me_status)
+ print("\n--- summary ---")
+ print(summary)
+ return code
+
+
+if __name__ == "__main__":
+ sys.exit(main())
+